diff --git a/README.md b/README.md index 916e5200b29841028652c861c49dbb3650baea3c..ef5bdc66ef03131318e1dde627e0224cca9137fd 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,10 @@ ----------------- -| **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | -|-----------------|---------------------|------------------|-------------------|---------------| -| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | + +| **`Documentation`** | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | +|-----------------|---------------------|------------------|-------------------|---------------|---------------| +| [![Documentation](https://img.shields.io/badge/api-reference-blue.svg)](https://www.tensorflow.org/api_docs/) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while @@ -21,20 +22,6 @@ organization for the purposes of conducting machine learning and deep neural networks research. The system is general enough to be applicable in a wide variety of other domains, as well. -**If you want to contribute to TensorFlow, be sure to review the [contribution -guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's -[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to -uphold this code.** - -**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for -tracking requests and bugs. So please see -[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions -and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** - -The TensorFlow project strives to abide by generally accepted best practices in open-source software development: - -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) - ## Installation *See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* @@ -75,6 +62,22 @@ $ python >>> sess.close() ``` +## Contribution guidelines + +**If you want to contribute to TensorFlow, be sure to review the [contribution +guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's +[code of conduct](CODE_OF_CONDUCT.md). By participating, you are expected to +uphold this code.** + +**We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for +tracking requests and bugs. So please see +[TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions +and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** + +The TensorFlow project strives to abide by generally accepted best practices in open-source software development: + +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) + ## For more information * [TensorFlow Website](https://www.tensorflow.org) diff --git a/RELEASE.md b/RELEASE.md index 0720a8c639f8ab87214b11f6a8092b432b916853..6f54dee58f75c29a16545ba25de12fe059baf1eb 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -21,7 +21,7 @@ newcomers. * Other: * Add `tf.contrib.distributions.Kumaraswamy`. * `RetryingFileSystem::FlushCaches()` calls the base FileSystem's `FlushCaches()`. - * Add auto_correlation to distributions. + * Add `auto_correlation` to distributions. * Add `tf.contrib.distributions.Autoregressive`. * Add SeparableConv1D layer. * Add convolutional Flipout layers. @@ -31,12 +31,12 @@ newcomers. * Output variance over trees predictions for classifications tasks. * For `pt` and `eval` commands, allow writing tensor values to filesystem as numpy files. * gRPC: Propagate truncated errors (instead of returning gRPC internal error). - * Augment parallel_interleave to support 2 kinds of prefetching. + * Augment `parallel_interleave` to support 2 kinds of prefetching. * Improved XLA support for C64-related ops log, pow, atan2, tanh. * Add probabilistic convolutional layers. ## API Changes -* Introducing prepare_variance boolean with default setting to False for backward compatibility. +* Introducing `prepare_variance` boolean with default setting to False for backward compatibility. * Move `layers_dense_variational_impl.py` to `layers_dense_variational.py`. ## Known Bugs @@ -96,27 +96,6 @@ Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田 * Starting from 1.6 release, our prebuilt binaries will use AVX instructions. This may break TF on older CPUs. -## Known Bugs -* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or - `CUDA_ILLEGAL_ADDRESS` failures. - - Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9 - and CUDA 9.1 sometimes does not properly compute the carry bit when - decomposing 64-bit address calculations with large offsets (e.g. `load [x + - large_constant]`) into 32-bit arithmetic in SASS. - - As a result, these versions of `ptxas` miscompile most XLA programs which use - more than 4GB of temp memory. This results in garbage results and/or - `CUDA_ERROR_ILLEGAL_ADDRESS` failures. - - A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a - fix for CUDA 9.0.x. Until the fix is available, the only workaround is to - [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x - or disable XLA:GPU. - - TensorFlow will print a warning if you use XLA:GPU with a known-bad version of - CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122. - ## Major Features And Improvements * [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager) preview version is now available. diff --git a/configure b/configure index 9c21d2b03a27714f05094667691e74c16fa89f35..66b66ba54ed68a9aa0ee556f84f68c3a83a495ab 100755 --- a/configure +++ b/configure @@ -8,7 +8,8 @@ if [ -z "$PYTHON_BIN_PATH" ]; then fi # Set all env variables -"$PYTHON_BIN_PATH" configure.py +CONFIGURE_DIR=$(dirname "$0") +"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@" echo "Configuration finished" diff --git a/configure.py b/configure.py index 68c9bbfb1c82418a23229e98e0e5f16cc504acc7..2410cf7e07c1fa9bf61f7e3c473e11fc1e699256 100644 --- a/configure.py +++ b/configure.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import argparse import errno import os import platform @@ -32,10 +33,6 @@ except ImportError: from distutils.spawn import find_executable as which # pylint: enable=g-import-not-at-top -_TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), - '.tf_configure.bazelrc') -_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), - 'WORKSPACE') _DEFAULT_CUDA_VERSION = '9.0' _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' @@ -51,6 +48,11 @@ _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] _DEFAULT_PROMPT_ASK_ATTEMPTS = 10 +_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__)) +_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' +_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) +_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE') + class UserInputError(Exception): pass @@ -119,22 +121,6 @@ def sed_in_place(filename, old, new): f.write(newdata) -def remove_line_with(filename, token): - """Remove lines that contain token from file. - - Args: - filename: string for filename. - token: string token to check if to remove a line from file or not. - """ - with open(filename, 'r') as f: - filedata = f.read() - - with open(filename, 'w') as f: - for line in filedata.strip().split('\n'): - if token not in line: - f.write(line + '\n') - - def write_to_bazelrc(line): with open(_TF_BAZELRC, 'a') as f: f.write(line + '\n') @@ -245,25 +231,30 @@ def setup_python(environ_cp): environ_cp['PYTHON_BIN_PATH'] = python_bin_path # Write tools/python_bin_path.sh - with open('tools/python_bin_path.sh', 'w') as f: + with open(os.path.join( + _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path) -def reset_tf_configure_bazelrc(): +def reset_tf_configure_bazelrc(workspace_path): """Reset file that contains customized config settings.""" open(_TF_BAZELRC, 'w').close() - - home = os.path.expanduser('~') - if not os.path.exists('.bazelrc'): - if os.path.exists(os.path.join(home, '.bazelrc')): - with open('.bazelrc', 'a') as f: - f.write('import %s/.bazelrc\n' % home.replace('\\', '/')) + bazelrc_path = os.path.join(workspace_path, '.bazelrc') + + data = [] + if os.path.exists(bazelrc_path): + with open(bazelrc_path, 'r') as f: + data = f.read().splitlines() + with open(bazelrc_path, 'w') as f: + for l in data: + if _TF_BAZELRC_FILENAME in l: + continue + f.write('%s\n' % l) + if is_windows(): + tf_bazelrc_path = _TF_BAZELRC.replace("\\", "/") else: - open('.bazelrc', 'w').close() - - remove_line_with('.bazelrc', 'tf_configure') - with open('.bazelrc', 'a') as f: - f.write('import %workspace%/.tf_configure.bazelrc\n') + tf_bazelrc_path = _TF_BAZELRC + f.write('import %s\n' % tf_bazelrc_path) def cleanup_makefile(): @@ -271,7 +262,8 @@ def cleanup_makefile(): These files could interfere with Bazel parsing. """ - makefile_download_dir = 'tensorflow/contrib/makefile/downloads' + makefile_download_dir = os.path.join( + _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads') if os.path.isdir(makefile_download_dir): for root, _, filenames in os.walk(makefile_download_dir): for f in filenames: @@ -456,7 +448,7 @@ def check_bazel_version(min_version): if which('bazel') is None: print('Cannot find bazel. Please install bazel.') sys.exit(0) - curr_version = run_shell(['bazel', '--batch', 'version']) + curr_version = run_shell(['bazel', '--batch', '--bazelrc=/dev/null', 'version']) for line in curr_version.split('\n'): if 'Build label: ' in line: @@ -502,7 +494,8 @@ def set_cc_opt_flags(environ_cp): for opt in cc_opt_flags.split(): write_to_bazelrc('build:opt --copt=%s' % opt) # It should be safe on the same build host. - write_to_bazelrc('build:opt --host_copt=-march=native') + if not is_ppc64le(): + write_to_bazelrc('build:opt --host_copt=-march=native') write_to_bazelrc('build:opt --define with_default_optimizations=true') # TODO(mikecase): Remove these default defines once we are able to get # TF Lite targets building without them. @@ -916,7 +909,7 @@ def set_tf_cudnn_version(environ_cp): tf_cudnn_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version, _DEFAULT_CUDNN_VERSION) - tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version) ,1) + tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1) default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH') ask_cudnn_path = (r'Please specify the location where cuDNN %s library is ' @@ -1081,7 +1074,7 @@ def set_tf_tensorrt_install_path(environ_cp): break # Reset and Retry - if len(possible_files): + if possible_files: print('TensorRT libraries found in one the following directories', 'are not compatible with selected cuda and cudnn installations') print(trt_install_path) @@ -1090,7 +1083,8 @@ def set_tf_tensorrt_install_path(environ_cp): if search_result: print(libnvinfer_path_from_ldconfig) else: - print('Invalid path to TensorRT. None of the following files can be found:') + print( + 'Invalid path to TensorRT. None of the following files can be found:') print(trt_install_path) print(os.path.join(trt_install_path, 'lib')) print(os.path.join(trt_install_path, 'lib64')) @@ -1231,7 +1225,7 @@ def set_host_c_compiler(environ_cp): environ_cp, var_name='HOST_C_COMPILER', var_default=default_c_host_compiler, - ask_for_var=('Please specify which C compiler should be used as the host' + ask_for_var=('Please specify which C compiler should be used as the host ' 'C compiler.'), check_success=os.path.exists, error_msg='Invalid C compiler path. %s cannot be found.', @@ -1375,13 +1369,20 @@ def config_info_line(name, help_text): def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--workspace", + type=str, + default=_TF_WORKSPACE_ROOT, + help="The absolute path to your active Bazel workspace.") + args = parser.parse_args() + # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. environ_cp = dict(os.environ) check_bazel_version('0.5.4') - reset_tf_configure_bazelrc() + reset_tf_configure_bazelrc(args.workspace) cleanup_makefile() setup_python(environ_cp) @@ -1436,8 +1437,10 @@ def main(): if is_linux(): set_tf_tensorrt_install_path(environ_cp) set_tf_cuda_compute_capabilities(environ_cp) - if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get('LD_LIBRARY_PATH') != '1': - write_action_env_to_bazelrc('LD_LIBRARY_PATH', environ_cp.get('LD_LIBRARY_PATH')) + if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( + 'LD_LIBRARY_PATH') != '1': + write_action_env_to_bazelrc('LD_LIBRARY_PATH', + environ_cp.get('LD_LIBRARY_PATH')) set_tf_cuda_clang(environ_cp) if environ_cp.get('TF_CUDA_CLANG') == '1': diff --git a/tensorflow/BUILD b/tensorflow/BUILD index dc995d231d3e591771f801e28024a76610cdba26..3828ee0ddbc9c4c679d5358db4579c312d3e2524 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -787,6 +787,7 @@ tf_cc_shared_object( }), deps = [ "//tensorflow/c:c_api", + "//tensorflow/c:c_api_experimental", "//tensorflow/c:exported_symbols.lds", "//tensorflow/c:version_script.lds", "//tensorflow/c/eager:c_api", diff --git a/tensorflow/SECURITY.md b/tensorflow/SECURITY.md index 074eed2951526d53ab62515b7b869569a9708299..fea24b273920885ba8a1ae96aafbf7710df46e1f 100644 --- a/tensorflow/SECURITY.md +++ b/tensorflow/SECURITY.md @@ -233,7 +233,7 @@ v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= ### Known vulnerabilities -| Type | Versions affected | Reported by | Additional Information | -|------|:-----------------:|---------------------------------------| -| out of bounds read| <=1.4 | @zhangbo5891001 | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | +| Type | Versions affected | Reported by | Additional Information | +|-------------------|:-----------------:|--------------------|-----------------------------| +| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 25a994be3ebd844a26ac8936e78058bb123a3f75..5dfb743681255d8c03e91ea43fd441d94fdee59d 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -6,17 +6,12 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", "tf_cc_test", + "tf_cuda_cc_test", "tf_copts", "tf_cuda_library", "tf_custom_op_library", ) -# For platform specific build config -load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_kernel_tests_linkstatic", -) - # ----------------------------------------------------------------------------- # Public targets @@ -33,7 +28,11 @@ filegroup( "*.cc", "*.h", ], - exclude = ["*test*"], + exclude = [ + "c_api_experimental.cc", + "c_api_experimental.h", + "*test*", + ], ), visibility = ["//visibility:public"], ) @@ -100,6 +99,24 @@ tf_cuda_library( }), ) +tf_cuda_library( + name = "c_api_experimental", + srcs = [ + "c_api_experimental.cc", + ], + hdrs = [ + "c_api_experimental.h", + ], + copts = tf_copts(), + visibility = ["//visibility:public"], + deps = [ + ":c_api", + ":c_api_internal", + "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + "//tensorflow/core:protos_all_cc", + ], +) + exports_files( [ "version_script.lds", @@ -147,7 +164,7 @@ tf_cuda_library( ], deps = [ ":c_api", - "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", + ":c_api_experimental", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:session_options", @@ -155,7 +172,7 @@ tf_cuda_library( ], ) -tf_cc_test( +tf_cuda_cc_test( name = "c_api_test", size = "small", srcs = ["c_api_test.cc"], diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index 6d30905a1afaf5845b66d2e1f956055727c4b38b..ad592ef70961ef427bfe9fd322a82bd64df7f9f1 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -1287,11 +1287,12 @@ TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); typedef struct TF_Session TF_Session; -// Return a new execution session with the associated graph, or NULL on error. +// Return a new execution session with the associated graph, or NULL on +// error. Does not take ownership of any input parameters. // -// *graph must be a valid graph (not deleted or nullptr). This function will -// prevent the graph from being deleted until TF_DeleteSession() is called. -// Does not take ownership of opts. +// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be be +// kept alive for the lifetime of the returned TF_Session. New nodes can still +// be added to `graph` after this call. TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opts, TF_Status* status); diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc new file mode 100644 index 0000000000000000000000000000000000000000..be7f85a5bb06dce84579b109d506ded049042b50 --- /dev/null +++ b/tensorflow/c/c_api_experimental.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/c_api_experimental.h" + +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/core/protobuf/config.pb.h" + +void TF_EnableXLACompilation(TF_SessionOptions* options, unsigned char enable) { + tensorflow::ConfigProto& config = options->options.config; + auto* optimizer_options = + config.mutable_graph_options()->mutable_optimizer_options(); + if (enable) { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::ON_1); + + // These XLA flags are needed to trigger XLA properly from C (more generally + // non-Python) clients. If this API is called again with `enable` set to + // false, it is safe to keep these flag values as is. + tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = + tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + flags->tf_xla_cpu_global_jit = true; + flags->tf_xla_min_cluster_size = 1; + } else { + optimizer_options->set_global_jit_level(tensorflow::OptimizerOptions::OFF); + } +} diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h new file mode 100644 index 0000000000000000000000000000000000000000..5a7b007e40aa199889b2d00b2bde5976c19e2966 --- /dev/null +++ b/tensorflow/c/c_api_experimental.h @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_C_API_EXPERIMENTAL_H_ +#define TENSORFLOW_C_C_API_EXPERIMENTAL_H_ + +#include +#include + +#include "tensorflow/c/c_api.h" + +// -------------------------------------------------------------------------- +// Experimental C API for TensorFlow. +// +// The API here is subject to changes in the future. + +// Macro to control visibility of exported symbols in the shared library (.so, +// .dylib, .dll). +// This duplicates the TF_EXPORT macro definition in +// tensorflow/core/platform/macros.h in order to keep this .h file independent +// of any other includes.$a +#ifdef SWIG +#define TF_CAPI_EXPORT +#else +#if defined(COMPILER_MSVC) +#ifdef TF_COMPILE_LIBRARY +#define TF_CAPI_EXPORT __declspec(dllexport) +#else +#define TF_CAPI_EXPORT __declspec(dllimport) +#endif // TF_COMPILE_LIBRARY +#else +#define TF_CAPI_EXPORT __attribute__((visibility("default"))) +#endif // COMPILER_MSVC +#endif // SWIG + +#ifdef __cplusplus +extern "C" { +#endif + +// When `enable` is true, set +// tensorflow.ConfigProto.OptimizerOptions.global_jit_level to ON_1, and also +// set XLA flag values to prepare for XLA compilation. Otherwise set +// global_jit_level to OFF. +// +// This API is syntax sugar over TF_SetConfig(), and is used by clients that +// cannot read/write the tensorflow.ConfigProto proto. +TF_CAPI_EXPORT extern void TF_EnableXLACompilation(TF_SessionOptions* options, + unsigned char enable); + +#ifdef __cplusplus +} /* end extern "C" */ +#endif + +#endif // TENSORFLOW_C_C_API_EXPERIMENTAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 654664166aa55c8ba9fe2678b8f7ec88e5ae5d7f..028f146be31790b211e546978302e81afe26b231 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -57,6 +57,52 @@ static void ExpectHasSubstr(StringPiece s, StringPiece expected) { << "'" << s << "' does not contain '" << expected << "'"; } +// Returns the GPU device name if there is one (with arbitrary tie breaking if +// there are more than one), or "" otherwise. +string GPUDeviceName(TF_Session* session) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Status* s = status.get(); + std::unique_ptr list( + TF_SessionListDevices(session, s), TF_DeleteDeviceList); + TF_DeviceList* device_list = list.get(); + + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + const int num_devices = TF_DeviceListCount(device_list); + LOG(INFO) << "There are " << num_devices << " devices."; + for (int i = 0; i < num_devices; ++i) { + const char* device_name = TF_DeviceListName(device_list, i, s); + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + const char* device_type = TF_DeviceListType(device_list, i, s); + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + LOG(INFO) << "Device " << i << " has name " << device_name << ", type " + << device_type; + if (string(device_type) == DEVICE_GPU) { + return device_name; + } + } + // No GPU device found. + return ""; +} + +string GPUDeviceName() { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Status* s = status.get(); + std::unique_ptr graph(TF_NewGraph(), + TF_DeleteGraph); + + TF_SessionOptions* opts = TF_NewSessionOptions(); + TF_Session* sess = TF_NewSession(graph.get(), opts, s); + TF_DeleteSessionOptions(opts); + + const string gpu_device_name = GPUDeviceName(sess); + TF_DeleteSession(sess, s); + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return gpu_device_name; +} + TEST(CAPI, Version) { EXPECT_STRNE("", TF_Version()); } TEST(CAPI, Status) { @@ -134,6 +180,10 @@ TEST(CAPI, MaybeMove) { } TEST(CAPI, LibraryLoadFunctions) { + // TODO(b/73318067): Fix linking for the GPU test generated by the + // tf_cuda_cc_test() bazel rule and remove the next line. + if (!GPUDeviceName().empty()) return; + // Load the library. TF_Status* status = TF_NewStatus(); TF_Library* lib = @@ -923,7 +973,9 @@ TEST(CAPI, Session) { TF_DeleteStatus(s); } -TEST(CAPI, Session_Min_CPU) { +// If `device` is non-empty, run Min op on that device. +// Otherwise run it on the default device (CPU). +void RunMinTest(const string& device, bool use_XLA) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -935,12 +987,14 @@ TEST(CAPI, Session_Min_CPU) { TF_Operation* one = ScalarConst(0, graph, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - // Add operation. - TF_Operation* min = Min(feed, one, graph, s); + // Create a session for this graph. + CSession csession(graph, s, use_XLA); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - // Create a session for this graph. - CSession csession(graph, s); + if (!device.empty()) { + LOG(INFO) << "Setting op Min on device " << device; + } + TF_Operation* min = MinWithDevice(feed, one, graph, device, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); // Run the graph. @@ -963,44 +1017,24 @@ TEST(CAPI, Session_Min_CPU) { TF_DeleteStatus(s); } -TEST(CAPI, Session_Min_XLA_CPU) { - TF_Status* s = TF_NewStatus(); - TF_Graph* graph = TF_NewGraph(); - - // Make a placeholder operation. - TF_Operation* feed = Placeholder(graph, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); +TEST(CAPI, Session_Min_CPU) { RunMinTest(/*device=*/"", /*use_XLA=*/false); } - // Make a constant operation with the scalar "0", for axis. - TF_Operation* one = ScalarConst(0, graph, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); +TEST(CAPI, Session_Min_XLA_CPU) { RunMinTest(/*device=*/"", /*use_XLA=*/true); } - // Add operation. - TF_Operation* min = Min(feed, one, graph, s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); +TEST(CAPI, Session_Min_GPU) { + const string gpu_device = GPUDeviceName(); + // Skip this test if no GPU is available. + if (gpu_device.empty()) return; - // Create a session for this graph. - CSession csession(graph, s, /*use_XLA=*/true); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + RunMinTest(gpu_device, /*use_XLA=*/false); +} - // Run the graph. - csession.SetInputs({{feed, Int32Tensor({3, 2, 5})}}); - csession.SetOutputs({min}); - csession.Run(s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_Tensor* out = csession.output_tensor(0); - ASSERT_TRUE(out != nullptr); - EXPECT_EQ(TF_INT32, TF_TensorType(out)); - EXPECT_EQ(0, TF_NumDims(out)); // scalar - ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); - int32* output_contents = static_cast(TF_TensorData(out)); - EXPECT_EQ(2, *output_contents); +TEST(CAPI, Session_Min_XLA_GPU) { + const string gpu_device = GPUDeviceName(); + // Skip this test if no GPU is available. + if (gpu_device.empty()) return; - // Clean up - csession.CloseAndDelete(s); - ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - TF_DeleteGraph(graph); - TF_DeleteStatus(s); + RunMinTest(gpu_device, /*use_XLA=*/true); } TEST(CAPI, SessionPRun) { @@ -2145,6 +2179,10 @@ TEST_F(CApiAttributesTest, Errors) { } TEST(TestApiDef, TestCreateApiDef) { + // TODO(b/73318067): Fix linking for the GPU test generated by the + // tf_cuda_cc_test() bazel rule and remove the next line. + if (!GPUDeviceName().empty()) return; + TF_Status* status = TF_NewStatus(); TF_Library* lib = TF_LoadLibrary("tensorflow/c/test_op.so", status); @@ -2175,6 +2213,10 @@ TEST(TestApiDef, TestCreateApiDef) { } TEST(TestApiDef, TestCreateApiDefWithOverwrites) { + // TODO(b/73318067): Fix linking for the GPU test generated by the + // tf_cuda_cc_test() bazel rule and remove the next line. + if (!GPUDeviceName().empty()) return; + TF_Status* status = TF_NewStatus(); TF_Library* lib = TF_LoadLibrary("tensorflow/c/test_op.so", status); diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index 2c5f08d6725dc9736055dfef98b4ed2f9252ed13..3db2852ce6560ba493d60ef54a110161c112d110 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/c/c_test_util.h" -#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" +#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" @@ -163,10 +163,14 @@ TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, return TF_FinishOperation(desc, s); } +// If `op_device` is non-empty, set the created op on that device. void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name, - TF_Operation** op, bool check) { + TF_Operation** op, const string& op_device, bool check) { TF_OperationDescription* desc = TF_NewOperation(graph, op_name, name); + if (!op_device.empty()) { + TF_SetDevice(desc, op_device.c_str()); + } TF_AddInput(desc, {l, 0}); TF_AddInput(desc, {r, 0}); *op = TF_FinishOperation(desc, s); @@ -176,13 +180,19 @@ void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r, } } -TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, - TF_Status* s, const char* name) { +TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + const string& op_device, TF_Status* s, + const char* name) { TF_Operation* op; - BinaryOpHelper("Min", l, r, graph, s, name, &op, true); + BinaryOpHelper("Min", l, r, graph, s, name, &op, op_device, true); return op; } +TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + return MinWithDevice(l, r, graph, /*op_device=*/"", s, name); +} + TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); @@ -394,19 +404,7 @@ std::vector GetFuncNames(const tensorflow::GraphDef& graph_def) { CSession::CSession(TF_Graph* graph, TF_Status* s, bool use_XLA) { TF_SessionOptions* opts = TF_NewSessionOptions(); - tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = - tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); - flags->tf_xla_cpu_global_jit = use_XLA; - if (use_XLA) { - tensorflow::ConfigProto config; - config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_global_jit_level(tensorflow::OptimizerOptions::ON_1); - std::string contents; - contents.resize(config.ByteSizeLong()); - config.SerializeToArray(&contents[0], contents.size()); - TF_SetConfig(opts, contents.data(), contents.size(), s); - } + TF_EnableXLACompilation(opts, use_XLA); session_ = TF_NewSession(graph, opts, s); TF_DeleteSessionOptions(opts); } diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index 805fafae05c4a9e5a8a2044471e86872a7668bda..2a70177c724c569844a5d8ad42b99bed20209946 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -72,6 +72,11 @@ TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name = "min"); +// If `op_device` is non-empty, set the created op on that device. +TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + const string& op_device, TF_Status* s, + const char* name = "min"); + TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, const char* name = "neg"); @@ -127,6 +132,8 @@ class CSession { TF_Tensor* output_tensor(int i) { return output_values_[i]; } + TF_Session* mutable_session() { return session_; } + private: void DeleteInputValues(); void ResetOutputValues(); diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 8e834eb99c13d1f26da9f0860897267efc2fd01c..bebb63c7462f2e85343553094824892eddce3ce6 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" @@ -154,17 +155,24 @@ TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) { return static_cast(h->t.dtype()); } -int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); } +int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) { + status->status = tensorflow::Status::OK(); + return h->t.dims(); +} -int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) { +int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index, + TF_Status* status) { + status->status = tensorflow::Status::OK(); return h->t.dim_size(dim_index); } -const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) { - // This might be a bit confusing as a tensor on CPU can sometimes return - // "CPU:0" and sometimes "/job:localhost/replica:0/task:0/cpu:0". - // TODO(ashankar): Figure out which one would be nicer. - return (h->d == nullptr) ? "CPU:0" : h->d->name().c_str(); +const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) { + // TODO(apassos) this will be potentially incorrect in the distributed case as + // our local device will have a name which depends on the ClusterSpec and + // hence will require the context to resolve. + status->status = tensorflow::Status::OK(); + return (h->d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0" + : h->d->name().c_str(); } TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) { @@ -296,11 +304,9 @@ void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { // Questionable heuristic ... - // - // Motivation: After an 'op' is placed on GPU because some of its earlier - // inputs are on GPU, we want to keep the 'op' there, even if some later - // inputs of it are not on GPU. - if (IsCPU(op->device) && !IsCPU(h->d)) { + // - If a device was explicitly set on the op, always use that. + // - If not, place on the first non-host device seen. + if (op->device == nullptr && !IsCPU(h->d)) { op->device = h->d; } if (!status->status.ok()) return; @@ -801,6 +807,10 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } if (kernel == nullptr) { const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); + if (ctx->log_device_placement) { + LOG(INFO) << "Executing op " << ndef.op() << " in device " + << device->name(); + } kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); // Knowledge of the implementation of Init (and in-turn // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def @@ -814,6 +824,25 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, delete kernel; return; } + // Update output_dtypes inside `kernel`. + const tensorflow::OpDef* op_def = nullptr; + const tensorflow::FunctionDef* function_def = + ctx->func_lib_def.Find(ndef.op()); + if (function_def != nullptr) { + op_def = &(function_def->signature()); + } + if (op_def == nullptr) { + status->status = OpDefForOp(ndef.op().c_str(), &op_def); + if (!status->status.ok()) { + return; + } + } + tensorflow::DataTypeVector input_dtypes; + status->status = InOutTypesForNode(ndef, *op_def, &input_dtypes, + kernel->output_dtypes()); + if (!status->status.ok()) { + return; + } tensorflow::mutex_lock ml(ctx->cache_mu); tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); } diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index 7a321b54da343fd2b8912187bc620c1e7456db0c..90cfb7500e26231052b7c942ba6d2aeeafab7dc9 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -119,11 +119,13 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status); TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); -TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); +TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h, + TF_Status* status); TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, - int dim_index); + int dim_index, + TF_Status* status); TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( - TFE_TensorHandle* h); + TFE_TensorHandle* h, TF_Status* status); TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 7b9f1db02ed9c53a280c7bd1284165cac4fb6353..3356054cd09939b24a1d942c0cced06136e33b85 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -50,7 +50,9 @@ struct TFE_Context { rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)), pflr(new tensorflow::ProcessFunctionLibraryRuntime( session->device_mgr, opts.session_options.options.env, - TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {} + TF_GRAPH_DEF_VERSION, &func_lib_def, {})), + log_device_placement( + opts.session_options.options.config.log_device_placement()) {} const TFE_ContextDevicePlacementPolicy policy; @@ -88,6 +90,8 @@ struct TFE_Context { std::atomic should_store_metadata{false}; tensorflow::mutex metadata_mu; tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); + + const bool log_device_placement; }; struct TFE_TensorHandle { diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 4a3ecbc0abb16296a84c0d2184dc3fc9f7f3ebb4..00fb7e68d00dd2ef316bf89b8f253cf6c7c63f00 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -932,7 +932,8 @@ TEST(CAPI, Variables) { ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); EXPECT_EQ(TF_FLOAT, TFE_TensorHandleDataType(value_handle)); - EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle)); + EXPECT_EQ(0, TFE_TensorHandleNumDims(value_handle, status)); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); float value = 0.0f; TF_Tensor* t = TFE_TensorHandleResolve(value_handle, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -974,7 +975,8 @@ void BM_ReadVariable(int iters) { CHECK_EQ(1, num_retvals); CHECK(h); CHECK_EQ(TF_FLOAT, TFE_TensorHandleDataType(h)); - CHECK_EQ(0, TFE_TensorHandleNumDims(h)); + CHECK_EQ(0, TFE_TensorHandleNumDims(h, status)); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); h = nullptr; } tensorflow::testing::StopTiming(); diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index f77a937f1ffc2d146224cb3191a5ca127daefc22..4bf24fec2cbceab3da0c6a39a2d68bcda5915de9 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -41,17 +41,26 @@ const uint32 kIsList = 1U << 31; } // namespace +Status OpDefForOp(const char* op_name, const OpDef** op_def) { + const OpRegistrationData* op_reg_data = nullptr; + Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); + if (s.ok()) { + *op_def = &op_reg_data->op_def; + } + return s; +} + Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { mutex_lock l(g_op_name_to_attr_type_map_lock); *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name); if (*out != nullptr) return Status::OK(); - const OpRegistrationData* op_reg_data = nullptr; - Status s = OpRegistry::Global()->LookUp(op_name, &op_reg_data); + const OpDef* op_def = nullptr; + Status s = OpDefForOp(op_name, &op_def); if (!s.ok()) return s; std::unique_ptr m(new AttrTypeMap); // TODO(agarwal): Avoid having to create this "registry" at runtime, // perhaps can be done at op registration time? - for (const auto& attr : op_reg_data->op_def.attr()) { + for (const auto& attr : op_def->attr()) { string type = attr.type(); const bool is_list = (type.length() > 6 && type.compare(0, 4, "list") == 0); if (is_list) { diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index 4d20b5244a46fcde2eed0a429dced2a77b86aedd..7fede4dae94f31e662a12757c49680d55118e922 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -39,6 +39,9 @@ namespace tensorflow { // represent the TF_AttrType type of the values in the list. typedef std::unordered_map AttrTypeMap; +// Look up OpDef for `op_name`. +Status OpDefForOp(const char* op_name, const OpDef** op_def); + // Returns the AttrTypeMap for the TensorFlow operation named op_name. Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); @@ -180,12 +183,15 @@ class KernelAndDevice { const OpKernel* kernel() const { return kernel_.get(); } + DataTypeVector* output_dtypes() { return &output_dtypes_; } + private: std::unique_ptr kernel_; Device* device_; FunctionLibraryRuntime* flib_; checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_; Rendezvous* rendez_; + DataTypeVector output_dtypes_; }; } // namespace tensorflow diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index 6e37cdb5f4beea53d4a2ded0705ae482d0bc2d68..f553142d15f476ad2c1af68016a4254ed211b9b2 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -99,4 +99,9 @@ void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { } } +void SetRequireShapeInferenceFns(TF_Graph* graph, bool require) { + mutex_lock l(graph->mu); + graph->refiner.set_require_shape_inference_fns(require); +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index aa9d9e06b28c54cb8869eb547d36ee3cb0d4e6b8..542d70f42c2a5df8309a722b32d850dd249e496f 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -37,6 +37,10 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); +// Sets whether ops missing a shape inference function should trigger an +// error. The default is true. +void SetRequireShapeInferenceFns(TF_Graph* graph, bool require); + } // namespace tensorflow #endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index 13a3bba5e6d5ca19ff3f0eca76665ba7d3ab628d..63a67f09f6f7c2b39da8cf082c2a36179014ac6f 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -196,6 +196,70 @@ Status MaxPoolGradV2Helper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("MaxPoolV2", MaxPoolGradV2Helper); +Status MaxPool3DGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + MaxPool3DGrad::Attrs grad_attrs; + grad_attrs.DataFormat(data_format); + auto dx = MaxPool3DGrad(scope, op.input(0), op.output(0), grad_inputs[0], + ksize, strides, padding, grad_attrs); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("MaxPool3D", MaxPool3DGradHelper); + +Status AvgPoolGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + internal::AvgPoolGrad::Attrs grad_attrs; + grad_attrs.DataFormat(data_format); + auto dx = + internal::AvgPoolGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], + ksize, strides, padding, grad_attrs); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("AvgPool", AvgPoolGradHelper); + +Status AvgPool3DGradHelper(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + std::vector ksize; + std::vector strides; + string padding; + string data_format; + auto attrs = op.output(0).node()->attrs(); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "ksize", &ksize)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "strides", &strides)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "padding", &padding)); + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "data_format", &data_format)); + AvgPool3DGrad::Attrs grad_attrs; + grad_attrs.DataFormat(data_format); + auto dx = AvgPool3DGrad(scope, Shape(scope, op.input(0)), grad_inputs[0], + ksize, strides, padding, grad_attrs); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("AvgPool3D", AvgPool3DGradHelper); + Status LRNGradHelper(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs){ diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 0cfe5f6e3c49f7c4a3cafbf48ff4e54a0ffd0d47..c4eba7ecb017fe4628140d75a63bc7f0f09deb7f 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -31,8 +31,11 @@ using ops::Elu; using ops::L2Loss; using ops::LogSoftmax; using ops::LRN; +using ops::AvgPool; +using ops::AvgPool3D; using ops::MaxPool; using ops::MaxPoolV2; +using ops::MaxPool3D; using ops::Placeholder; using ops::Relu; using ops::Relu6; @@ -70,9 +73,9 @@ class NNGradTest : public ::testing::Test { // Sets tensor with random values, ensuring that the max value is largest by // a reasonable amount. - // This is an issue for MaxPool and MaxPoolV2, in which perturbations by the - // numeric gradient computation in the gradient checker can change the max - // value if values are too close together. + // This is an issue for MaxPool, MaxPoolV2 and MaxPool3D, in which + // perturbations by the numeric gradient computation in the gradient checker + // can change the max value if values are too close together. template void SetRandomValuesWithBumpedMax(Tensor* tensor) { auto tensor_flat = tensor->flat(); @@ -203,6 +206,41 @@ TEST_F(NNGradTest, MaxPoolGradV2Helper) { RunTest(x, x_init_value, y, y_shape); } +TEST_F(NNGradTest, MaxPool3DGradHelper) { + TensorShape x_shape({1, 3, 3, 3, 1}); + TensorShape y_shape({1, 1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one MaxPool3D. + const std::vector ksize{1, 3, 3, 3, 1}; + const std::vector strides{1, 3, 3, 3, 1}; + auto y = MaxPool3D(scope_, x, ksize, strides, "VALID"); + Tensor x_init_value = Tensor(DT_FLOAT, x_shape); + SetRandomValuesWithBumpedMax(&x_init_value); + RunTest(x, x_init_value, y, y_shape); +} + +TEST_F(NNGradTest, AvgPoolGradHelper) { + TensorShape x_shape({1, 2, 2, 1}); + TensorShape y_shape({1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one AvgPool. + const std::vector ksize{1, 2, 2, 1}; + const std::vector strides{1, 2, 2, 1}; + auto y = AvgPool(scope_, x, ksize, strides, "SAME"); + RunTest(x, x_shape, y, y_shape); +} + +TEST_F(NNGradTest, AvgPool3DGradHelper) { + TensorShape x_shape({1, 3, 3, 3, 1}); + TensorShape y_shape({1, 1, 1, 1, 1}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + // Setup window and strides so that we only do one AvgPool3D. + const std::vector ksize{1, 3, 3, 3, 1}; + const std::vector strides{1, 3, 3, 3, 1}; + auto y = AvgPool3D(scope_, x, ksize, strides, "SAME"); + RunTest(x, x_shape, y, y_shape); +} + TEST_F(NNGradTest, LRN){ TensorShape x_shape({1, 1, 2, 1}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h index 6077c45c5854fd5812ccb7c91522f93ed4e54883..64edbb5766c3604fbe0f15c2299843718381aa3f 100644 --- a/tensorflow/cc/profiler/profiler.h +++ b/tensorflow/cc/profiler/profiler.h @@ -61,18 +61,18 @@ class Profiler { /// Adds tracing information `run_meta` to profiler. A `run_meta` is /// generated by a TensorFlow session run call. `step` is the key /// to the `run_meta`. When calling ProfileXXX methods, caller can specify - /// `step` in `options` to seletively profile the corresponding `run_meta`. + /// `step` in `options` to selectively profile the corresponding `run_meta`. /// Multiple different `run_meta` can be keyed by the same `step` in order /// to group them together. void AddStep(int64 step, const RunMetadata& run_meta); /// Profiles the model by organizing nodes in graph structure. - /// Each node is an op and the nodes are contected by the op inputs/outputs. + /// Each node is an op and the nodes are connected by the op inputs/outputs. GraphNodeProto ProfileGraph(const Options& options); /// Profiles the model by organizing nodes in name scope structure. /// Each node is an op, and nodes are organized by the ops' name - /// scope, similar to a filesystem tree. + /// scope, similar to a file system tree. /// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2. GraphNodeProto ProfileNameScope(const Options& options); diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index eb3e632c7bb0f37fa7bdaeed7b0687abf9545718..9dff1be09fede6f65f82c2f36d94be07e781949f 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -224,9 +224,6 @@ def tf_library(name, graph, config, # TODO(cwhipkey): only depend on kernel code that the model actually needed. "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", - "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx", - "//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon", - "//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1", "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", "//tensorflow/compiler/xla/service/cpu:runtime_matmul", "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index a711319607f4ff2b83aa0ebe50e215b3d0e2258e..af259e0564c885836cf0e49c8b29c6169f059c5a 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -102,12 +102,17 @@ cc_library( cc_library( name = "xla_interpreter_device", srcs = ["xla_interpreter_device.cc"], + visibility = [":friends"], deps = [ + ":jit_compilation_passes", ":xla_device", "//tensorflow/compiler/jit/kernels:xla_launch_op", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/compiler/xla/service:interpreter_plugin", # buildcleaner: keep + "//tensorflow/core:lib", ], - alwayslink = True, + alwayslink = 1, ) cc_library( diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 2614deefd8823dcb8f38e9e22ae4e78145d0d96a..a329451b14a785b17913e3838a6571b62b422804 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -25,8 +25,8 @@ namespace tensorflow { const char* const DEVICE_XLA_INTERPRETER = "XLA_INTERPRETER"; const char* const DEVICE_INTERPRETER_XLA_JIT = "XLA_INTERPRETER_JIT"; -constexpr std::array kExecAllTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL, DT_DOUBLE, DT_INT64}}; +constexpr std::array kExecAllTypes = { + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; class XlaInterpreterDeviceFactory : public DeviceFactory { public: diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 25e329b6aadbab7219d7120ce5f51b3a6f5884e9..782bf82d4149968d5e5fbfb93bbd4ff1dcd75494 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -479,8 +479,9 @@ tf_xla_py_test( tf_xla_py_test( name = "reverse_sequence_op_test", - size = "small", + size = "medium", srcs = ["reverse_sequence_op_test.py"], + tags = ["optonly"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -639,6 +640,7 @@ tf_xla_py_test( name = "variable_ops_test", size = "small", srcs = ["variable_ops_test.py"], + tags = ["optonly"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -677,6 +679,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "scatter_nd_op_test", + size = "medium", + srcs = ["scatter_nd_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "xla_device_test", size = "small", @@ -801,6 +816,17 @@ tf_library( tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"], ) +tf_xla_py_test( + name = "fake_quant_ops_test", + size = "medium", + srcs = ["fake_quant_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe9400ef0f55ca011d4e23ba5d735899ca2e054 --- /dev/null +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -0,0 +1,452 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.platform import googletest + + +class FakeQuantWithMinMaxArgsTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxArgs operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. + def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + expected = np.array( + [ + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_max, expected_nudged_input_max, + expected_nudged_input_max, expected_nudged_input_max + ], + dtype=np.float32) + + with self.test_session() as session: + with self.test_scope(): + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + outputs = array_ops.fake_quant_with_min_max_args( + input_placeholder, + min=input_min, + max=input_max, + num_bits=num_bits, + narrow_range=narrow_range) + result = session.run(outputs, {input_placeholder: inputs}) + self.assertAllCloseAccordingToType( + result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) + + +class FakeQuantWithMinMaxArgsGradientTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxArgsGradient operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. + def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + gradients = np.arange(1, len(inputs) + 1, dtype=np.float32) + expected_backprops = np.array( + [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], + dtype=np.float32) + + with self.test_session() as session: + with self.test_scope(): + gradient_placeholder = array_ops.placeholder( + dtypes.float32, gradients.shape, name="gradients") + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + outputs = gen_array_ops.fake_quant_with_min_max_args_gradient( + gradient_placeholder, + input_placeholder, + min=input_min, + max=input_max, + num_bits=num_bits, + narrow_range=narrow_range) + backprops = session.run(outputs, { + gradient_placeholder: gradients, + input_placeholder: inputs + }) + self.assertAllCloseAccordingToType( + backprops, + expected_backprops, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + + +class FakeQuantWithMinMaxVarsTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxVars operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. + def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + expected = np.array( + [ + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_max, expected_nudged_input_max, + expected_nudged_input_max, expected_nudged_input_max + ], + dtype=np.float32) + + with self.test_session() as session: + with self.test_scope(): + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min") + max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max") + outputs = array_ops.fake_quant_with_min_max_vars( + input_placeholder, + min_placeholder, + max_placeholder, + num_bits=num_bits, + narrow_range=narrow_range) + result = session.run( + outputs, { + input_placeholder: inputs, + min_placeholder: input_min, + max_placeholder: input_max + }) + self.assertAllCloseAccordingToType( + result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) + + +class FakeQuantWithMinMaxVarsGradientTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxVarsGradient operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. + def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + gradients = np.arange(1, len(inputs) + 1, dtype=np.float32) + expected_backprops_wrt_input = np.array( + [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], + dtype=np.float32) + expected_backprops_wrt_min = 1.0 + 2.0 + expected_backprops_wrt_max = 10.0 + 11.0 + + with self.test_session() as session: + with self.test_scope(): + gradient_placeholder = array_ops.placeholder( + dtypes.float32, gradients.shape, name="gradients") + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min") + max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max") + outputs = array_ops.fake_quant_with_min_max_vars_gradient( + gradient_placeholder, + input_placeholder, + min_placeholder, + max_placeholder, + num_bits=num_bits, + narrow_range=narrow_range) + backprops_wrt_input, backprops_wrt_min, backprops_wrt_max = session.run( + outputs, { + gradient_placeholder: gradients, + input_placeholder: inputs, + min_placeholder: input_min, + max_placeholder: input_max + }) + self.assertAllCloseAccordingToType( + backprops_wrt_input, + expected_backprops_wrt_input, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + self.assertAllCloseAccordingToType( + backprops_wrt_min, + expected_backprops_wrt_min, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + self.assertAllCloseAccordingToType( + backprops_wrt_max, + expected_backprops_wrt_max, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..638946e234daf28dc4a34e6c33fc0f78b8e8699b --- /dev/null +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -0,0 +1,188 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.scatter_nd.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +def _AsType(v, vtype): + return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v) + + +def _FlatInnerDims(tensor, ndims=2): + shape = list(tensor.shape) + return tensor.reshape( + [functools.reduce(lambda x, y: x * y, shape[:-ndims + 1], 1)] + + shape[-ndims + 1:]) + + +def _FlatOuterDims(tensor, ndims=2): + shape = list(tensor.shape) + return tensor.reshape( + shape[:ndims - 1] + + [functools.reduce(lambda x, y: x * y, shape[ndims - 1:], 1)]) + + +def _NumpyScatterNd(ref, indices, updates, op): + ixdim = indices.shape[-1] + num_updates = indices.size // ixdim + total_nd = len(ref.shape) + slice_size = 1 + for i in range(ixdim, total_nd): + slice_size *= ref.shape[i] + flat_indices = _FlatInnerDims(indices) + flat_updates = updates.reshape((num_updates, slice_size)) + output_flat = _FlatOuterDims(ref, ixdim + 1) + for ix_updates, ix_output in enumerate(flat_indices): + ix_output = tuple(ix_output) + output_flat[ix_output] = op(output_flat[ix_output], + flat_updates[ix_updates]) + return output_flat.reshape(ref.shape) + + +def _NumpyUpdate(indices, updates, shape): + ref = np.zeros(shape, dtype=updates.dtype) + return _NumpyScatterNd(ref, indices, updates, lambda p, u: u) + + +class ScatterNdTest(XLATestCase): + + def _VariableRankTest(self, + np_scatter, + tf_scatter, + vtype, + itype, + repeat_indices=False): + np.random.seed(8) + ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)] + indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)] + for ref_shape, indices_shape in zip(ref_shapes, indices_shapes): + num_updates = indices_shape[0] + ixdim = indices_shape[-1] + + indexable_area_shape = () + for i in range(ixdim): + indexable_area_shape += (ref_shape[i],) + all_indices = [ + list(coord) + for coord, _ in np.ndenumerate(np.empty(indexable_area_shape, vtype)) + ] + np.random.shuffle(all_indices) + indices = np.array(all_indices[:num_updates]) + + if num_updates > 1 and repeat_indices: + indices = indices[:num_updates // 2] + for _ in range(num_updates - num_updates // 2): + indices = np.append( + indices, [indices[np.random.randint(num_updates // 2)]], axis=0) + np.random.shuffle(indices) + indices = _AsType(indices[:num_updates], itype) + + updates_shape = (num_updates,) + for i in range(ixdim, len(ref_shape)): + updates_shape += (ref_shape[i],) + updates = _AsType(np.random.randn(*(updates_shape)), vtype) + + # Scatter via numpy + np_out = np_scatter(indices, updates, ref_shape) + # Scatter via tensorflow + tf_out = tf_scatter(indices, updates, ref_shape) + + self.assertAllClose(np_out, tf_out) + + def _VariableRankTests(self, np_scatter, tf_scatter): + for vtype in self.numeric_types: + for itype in set([np.int32, np.int64]).intersection(set(self.int_types)): + self._VariableRankTest(np_scatter, tf_scatter, vtype, itype) + + def _runScatterNd(self, indices, updates, shape): + with self.test_session(): + updates_placeholder = array_ops.placeholder(updates.dtype) + indices_placeholder = array_ops.placeholder(indices.dtype) + with self.test_scope(): + output = array_ops.scatter_nd(indices_placeholder, updates_placeholder, + shape) + feed_dict = {updates_placeholder: updates, indices_placeholder: indices} + return output.eval(feed_dict=feed_dict) + + def testSimple(self): + indices = np.array([[4], [3], [1], [7]], dtype=np.int32) + updates = np.array([9, 10, 11, 12], dtype=np.float32) + expected = np.array([0, 11, 0, 10, 9, 0, 0, 12], dtype=np.int32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [8])) + + def testSimple2(self): + indices = np.array([[1, 0], [1, 1]], dtype=np.int32) + updates = np.array([11., 12.], dtype=np.float32) + expected = np.array([[0., 0.], [11., 12.], [0., 0.]], dtype=np.float32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2])) + + def testSimple3(self): + indices = np.array([[1]], dtype=np.int32) + updates = np.array([[11., 12.]], dtype=np.float32) + expected = np.array([[0., 0.], [11., 12.], [0., 0.]]) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2])) + + def testVariableRankUpdate(self): + self._VariableRankTests(_NumpyUpdate, self._runScatterNd) + + def testExtraIndicesDimensions(self): + indices = np.zeros([1, 1, 2], np.int32) + updates = np.zeros([1, 1], np.int32) + expected = np.zeros([2, 2], dtype=np.int32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2])) + + def testRank3InvalidShape1(self): + indices = np.zeros([3, 2, 2], np.int32) + updates = np.zeros([2, 2, 2], np.int32) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "Must have updates.shape"): + self._runScatterNd(indices, updates, [2, 2, 2]) + + def testRank3InvalidShape2(self): + indices = np.zeros([2, 2, 1], np.int32) + updates = np.zeros([2, 2], np.int32) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "Must have updates.shape"): + self._runScatterNd(indices, updates, [2, 2, 2]) + + def testScatterOutOfRange(self): + updates = np.array([-3, -4, -5]).astype(np.float32) + + # Indices all in range, no problem. + indices = np.array([[2], [0], [5]], dtype=np.int32) + self._runScatterNd(indices, updates, [6]) + + # Indices out of range should not fail. It produces implementation-defined + # output. + indices = np.array([[-1], [0], [5]], dtype=np.int32) + self._runScatterNd(indices, updates, [6]) + indices = np.array([[2], [0], [6]], dtype=np.int32) + self._runScatterNd(indices, updates, [6]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 260a04421b62310c109d8f0ea72875a50c234bb0..4a9c0e7471f9cdb2a47b54705495d2dda9748890 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -60,6 +60,14 @@ class SegmentReductionOpsTest(XLATestCase): np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) + def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array([6, 3, 0, 6], dtype=dtype), + self.UnsortedSegmentSum( + np.array([0, 1, 2, 3, 4, 5, 6], dtype=dtype), + np.array([3, -1, 0, 1, 0, -1, 3], dtype=np.int32), 4)) + def testUnsortedSegmentSum1DIndices2DDataDisjoint(self): for dtype in self.numeric_types: data = np.array( diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index a7cbfb04003c397212a35e16c6b23d7c2a18f7df..305ca0c6b78d3ef985deb38816f9388e7983906b 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest @@ -137,6 +138,34 @@ class StridedSliceTest(XLATestCase): self.assertAllEqual([6, 4], result) + def test2DDegenerate(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[2, 3]) + with self.test_scope(): + o = array_ops.strided_slice(i, [-1, 0], [0, 3]) + params = { + i: [[0, 1, 2], + [3, 4, 5]] + } + result = o.eval(feed_dict=params) + + self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape) + + def test2DDegenerateNegativeStride(self): + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[2, 3]) + with self.test_scope(): + o = array_ops.strided_slice(i, [0, 0], [-1, 3], [-1, 1]) + params = { + i: [[0, 1, 2], + [3, 4, 5]] + } + result = o.eval(feed_dict=params) + + self.assertEqual(tensor_shape.TensorShape((0, 3)), result.shape) + def test3D(self): for dtype in self.numeric_types: with self.test_session(): diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 3c7dfef03dfb5d86dd63fd4aa84ad56081833035..fb82c2601c432cee425a46a3b6dc2c55febeda87 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -312,6 +312,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index 82923722c54d235716b9138d95a75a441df924ca..6f46532419d3389bafe8c3bf41fa41e8a3e173b7 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -37,7 +37,7 @@ Status BackwardsConstAnalysis(const Graph& g, }; Status status; - std::unordered_set must_be_const; + std::unordered_set must_be_const; auto visit = [&status, &metadata_ops, &must_be_const, compile_time_const_args](Node* node) { if (!status.ok()) return; @@ -55,7 +55,7 @@ Status BackwardsConstAnalysis(const Graph& g, compile_time_const_args->at(index) = true; return; } - for (Node* pred : node->in_nodes()) { + for (const Node* pred : node->in_nodes()) { must_be_const.insert(pred); } return; diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index bf304102ede610e952a424f0b24505a14692f8ed..8b7beef83ec2ed0df780d6a9cb2a4bcf737d008b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -285,7 +285,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, Status FunctionalizeLoop(Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " - << dump_graph::DumpGraphToFile("functionalize_before", *graph); + << dump_graph::DumpGraphToFile("functionalize_before", *graph, + library); // Split loop-varying Enter nodes with multiple successors. If the same // Tensor is fed as input to multiple loop arguments, we may end up with a @@ -470,7 +471,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); VLOG(2) << "Frame " << frame->name << " condition: " - << dump_graph::DumpGraphToFile("loop_condition", *cond_graph) + << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); static std::atomic sequence_num(0LL); @@ -551,7 +552,8 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, frame->parent->nodes.insert(while_node); VLOG(2) << "Frame " << frame->name << " after: " - << dump_graph::DumpGraphToFile("functionalize_after", *graph); + << dump_graph::DumpGraphToFile("functionalize_after", *graph, + library); return Status::OK(); } @@ -581,14 +583,16 @@ class FunctionalizeCond { // CondArgNode represents a input to the conditional and its corresponding // switch nodes. struct CondArgNode { - explicit CondArgNode(Node* input) : input(input) {} + explicit CondArgNode(Node* src, int src_output) + : src(src), src_output(src_output) {} string ToString() const { - return strings::StrCat("input=", input->name(), - " switches=", NodesToString(switch_nodes)); + return strings::StrCat("src=", src->name(), ":", src_output, + " switches=", NodesToString(switches)); } - Node* input; - std::vector switch_nodes; + Node* src; + int src_output; + std::vector switches; }; using CondArgNodes = std::vector; @@ -602,15 +606,23 @@ class FunctionalizeCond { int count; }; - struct PredicateSwitches { - explicit PredicateSwitches(Node* predicate) : predicate(predicate) {} + // Group of switch nodes that will be part of the same XlaIf. + struct SwitchCluster { + explicit SwitchCluster(const Edge* predicate_edge) + : predicate_edge(predicate_edge) {} + string ToString() const { + return strings::StrCat(name, " predicate=", predicate_edge->src()->name(), + " switches=", NodesToString(switches)); + } - Node* predicate; + string name; + const Edge* predicate_edge; std::vector switches; }; - FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : library_(library), graph_(graph) {} + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + bool dump_graphs) + : library_(library), graph_(graph), dump_graphs_(dump_graphs) {} // Perform the actual cond functionalization. Iterate over groups of switch // nodes (linked by common predicate), from innermost to outermost, and @@ -621,40 +633,38 @@ class FunctionalizeCond { // frontier (the nodes where the cond ends). StatusOr, std::unordered_set>> - DetermineBranchMapAndFrontier(const std::vector& switches); + DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster); // Returns XlaIf node created from subgraph of merge and switch nodes. This // encapsulates the process of extracting the bodies needed for the then and // else branch, creates a XlaIf node, removing the nodes of the branches from // the graph and replacing the merge node with a XlaIf. StatusOr ConvertToXlaIf(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, - const std::vector& merge_nodes, - Node* predicate); + const SwitchCluster& switch_cluster, + const std::vector& switches); // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. StatusOr BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, - const std::vector& merge_nodes, - Node* predicate); + const SwitchCluster& switch_cluster, + const std::vector& merge_nodes); // Extracts a function body corresponding to the given input edge of the merge // node. Status ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, + const std::vector& switches, const std::vector& merge_nodes, int input_edge, Graph* body); // Adds all the input edges to `if_node` corresponding to the arguments. - Status AddInputEdges(const CondArgNodes& cond_arg_nodes, Node* predicate, - Node* if_node); + Status AddInputEdges(const CondArgNodes& cond_arg_nodes, + const Edge* predicate_edge, Node* if_node); // Adds all output edges from the `if_node`. Status AddOutputEdges(const std::vector& outputs, Node* if_node); - // Returns the switches of graph_ (along with grouping predicates) in - // postorder. Dead switch nodes are skipped and removed from the graph. - std::vector DeterminePredicateSwitchOrder(); + // Returns the switch clusters of graph_ in postorder. Dead switch nodes are + // skipped and removed from the graph. + StatusOr> DeterminePredicateSwitchOrder(); // Update the state for destination based on the state of source and the node // being updated. @@ -677,6 +687,7 @@ class FunctionalizeCond { FunctionLibraryDefinition* library_; Graph* graph_; + bool dump_graphs_; }; bool IsDeadSwitch(const Node* node) { @@ -724,10 +735,13 @@ Status FunctionalizeCond::ValidateFrontier( ") in both Else and Then branch should be in Both."); } } - if (pending[kBoth].empty() && pending[kThenBranch].empty() && - pending[kElseBranch].empty()) { - return errors::Internal("Unexpected empty frontier for switch nodes"); - } + // An empty frontier indicates a dead switch. Above we attempt to remove dead + // switch nodes, but not all are removed so don't treat it as an error yet. + // TODO(jpienaar): Find out why dead switch nodes remain. + // if (pending[kBoth].empty() && pending[kThenBranch].empty() && + // pending[kElseBranch].empty()) { + // return errors::Internal("Unexpected empty frontier for switch nodes"); + // } return Status::OK(); } @@ -745,8 +759,8 @@ Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, if (IsMerge(dst)) { dst_state->branch = Branch::kBoth; } else { - return errors::Internal("Illegal merge: ", src_state.ToString(), " with ", - dst_state->ToString(), " for ", + return errors::Internal("Illegal merge:\n", src_state.ToString(), + " with ", dst_state->ToString(), " for\n", dst->DebugString()); } } @@ -754,48 +768,217 @@ Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, return Status::OK(); } -std::vector +StatusOr> FunctionalizeCond::DeterminePredicateSwitchOrder() { + struct Cluster { + bool operator==(const Cluster& other) const { + return representative == other.representative; + } + int representative = -1; + }; + + // Perform a DFS over the graph and + // * Determine the reverse topological order of the nodes (there should be no + // cycles at this point so the post-order numbering corresponds to the + // reverse topological sorting); + // * Identify dead switches; + // * Initialize the cluster's representative; + std::vector> clusters(graph_->num_node_ids()); std::vector dead_switches; std::vector switch_order; - DFS(*graph_, nullptr, [this, &dead_switches, &switch_order](Node* n) { + std::vector rev_topo_sorted_nodes; + DFS(*graph_, nullptr, [&](Node* n) { + clusters[n->id()].Get().representative = n->id(); if (IsSwitch(n)) { if (IsDeadSwitch(n)) { dead_switches.push_back(n); } else { + rev_topo_sorted_nodes.push_back(n); switch_order.push_back(n); } + } else if (n->IsOp()) { + // Exclude src and sink nodes from further consideration. + rev_topo_sorted_nodes.push_back(n); } }); + std::vector switch_clusters; + // Return early if there are no switches in the graph. + if (switch_order.empty()) { + return switch_clusters; + } + // Remove all dead switch nodes. for (Node* n : dead_switches) { VLOG(2) << "Removing dead switch: " << n->DebugString(); graph_->RemoveNode(n); } - std::vector predicate_switch_order; - if (switch_order.empty()) { - return predicate_switch_order; + // Identify switch nodes that are part of the same control flow context by + // considering the operands of operations: an operation is part of the same + // control context as its operands unless the operation is a switch. Control + // dependencies are considered part of the same control flow context if the + // switch depth is the same (see comment below). + + // entry_cluster records the input cluster to a switch node. This is used when + // merging with a merge node where the dst's cluster is merged with the entry + // cluster of the merge node's cluster (which corresponds to a switch cluster + // and so has an entry cluster). + std::unordered_map*> entry_cluster; + + // Returns the output cluster of a node. Where the output cluster is cluster + // where the output of the node is used. For non-merge nodes this is simply + // the cluster they are part of, while for merge nodes it is the entry cluster + // of the cluster they are part of (this will correspond to the entry node of + // a switch node that dominates the merge). + auto find_output_cluster = [&](Node* n) { + UnionFind* cluster = &clusters[n->id()]; + if (!IsMerge(n)) return cluster; + auto it = entry_cluster.find(clusters[n->id()].Get().representative); + // If the cluster is not found in the entry_cluster map then an + // instruction not dominated by a switch node has been merged into the + // cluster of the merge. This indicates a failure of the clustering. + CHECK(it != entry_cluster.end()) + << "Unable to find entry for n=" << n->id() << " (" + << cluster->Get().representative << ")"; + return it->second; + }; + + // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier. + std::vector switch_depth(graph_->num_node_ids()); + for (auto it = rev_topo_sorted_nodes.rbegin(); + it != rev_topo_sorted_nodes.rend(); ++it) { + Node* n = *it; + + // Compute switch depth. + int new_switch_depth = 0; + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + new_switch_depth = std::max( + new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0)); + } + switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0); + + // Only merge the input operands of a switch. The switch's clustering itself + // is determined by the interaction of the switch's outputs. + if (IsSwitch(n)) { + Node* input; + TF_CHECK_OK(n->input_node(0, &input)); + entry_cluster[n->id()] = find_output_cluster(input); + UnionFind* cluster = entry_cluster[n->id()]; + int cluster_depth = switch_depth[cluster->Get().representative]; + // Merge the inputs of the switch node with one another. This results in + // predicates and control input residing in the same cluster. + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + UnionFind* src_cluster = find_output_cluster(src); + int src_cluster_depth = switch_depth[src_cluster->Get().representative]; + if (cluster_depth != src_cluster_depth) { + return errors::InvalidArgument( + "Unable to functionalize control flow in graph: Switch ('", + n->name(), "') has operands ('", input->name(), "' and '", + src->name(), "') that have different switch depths (", + cluster_depth, " != ", src_cluster_depth, ")"); + } + cluster->Merge(src_cluster); + } + continue; + } + + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + if (!src->IsOp()) continue; + UnionFind* cluster = find_output_cluster(src); + // Merge a node with its data operands and with its control operands if + // the src and dst are in the same ControlContext. The ControlContext is + // not explicitly available here, and instead the switch depth is used as + // a proxy here. Due to the invariant that control edges can only be from + // a containing scope to an inner scope or from the inner scope to its + // containing scope (for exit nodes), the switch depth will only match if + // the src and dst are in the same ControlContext. Control edges between + // ControlContexts are handled during the extraction. + int src_id = cluster->Get().representative; + int src_depth = switch_depth[src_id]; + if (!e->IsControlEdge() || new_switch_depth == src_depth) { + if (src_depth != new_switch_depth) { + return errors::InvalidArgument( + "Unable to functionalize control flow in graph: Operand ('", + src->name(), "') and operator ('", n->name(), + "') have different switch depths (", src_depth, + " != ", new_switch_depth, ")"); + } + cluster->Merge(&clusters[n->id()]); + } + } + } + + if (dump_graphs_) { + // Mark the switch cluster each node is part of. + for (Node* n : graph_->nodes()) { + n->ClearAttr("_XlaFunctionalizeSwitchGroup"); + n->AddAttr("_XlaFunctionalizeSwitchGroup", + clusters[n->id()].Get().representative); + } + LOG(INFO) << "FunctionalizeControlFlow (with_clusters): " + << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_, + library_); + } + + // Verify all the nodes of a cluster are at the same depth. + std::unordered_map> cluster_to_depth_node; + for (Node* n : graph_->nodes()) { + int depth = switch_depth[n->id()]; + int cluster_rep = clusters[n->id()].Get().representative; + auto it = cluster_to_depth_node.find(cluster_rep); + if (it == cluster_to_depth_node.end()) { + cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n); + } else { + if (it->second.first != depth) { + return errors::Internal( + "Illegal clustering created, mismatch in depths:", "\n\t", + n->DebugString(), "(", clusters[n->id()].Get().representative, + ") at depth=", depth, " vs\n\t", it->second.second->DebugString(), + "(", clusters[n->id()].Get().representative, ") at depth ", + it->second.first); + } + } } + struct Hash { + size_t operator()(const std::pair& item) const { + return Hash64Combine(hash()(item.first), + std::hash()(item.second.representative)); + } + }; + // Merge Switch nodes with common predicate. - std::unordered_map predicate_index; + std::unordered_map, int, Hash> predicate_index; // The nodes in switch_order are in reverse topological order, but the // clustered switches need not be (i.e., when considered as a cluster one // element of a cluster may be later in the topological order than another // node whose cluster is later in the topological order of clustered // switches). for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { - Node* pred; - TF_CHECK_OK((*it)->input_node(1, &pred)); - if (predicate_index.find(pred) == predicate_index.end()) { - predicate_index[pred] = predicate_switch_order.size(); - predicate_switch_order.emplace_back(pred); + const Edge* pred_edge; + TF_CHECK_OK((*it)->input_edge(1, &pred_edge)); + // The predicate can be preceded by a identity node. Look through identity + // nodes to predicate. + while (pred_edge->src()->IsIdentity()) { + TF_CHECK_OK(pred_edge->src()->input_edge(0, &pred_edge)); } - predicate_switch_order[predicate_index[pred]].switches.push_back(*it); + auto repr = std::make_pair(pred_edge->src(), clusters[(*it)->id()].Get()); + if (predicate_index.find(repr) == predicate_index.end()) { + predicate_index[repr] = switch_clusters.size(); + switch_clusters.emplace_back(pred_edge); + // Generate a name by concatenating with the cluster representative as + // there could be multiple switch clusters with the same predicate. + switch_clusters[predicate_index[repr]].name = strings::StrCat( + pred_edge->src()->name(), "_", repr.second.representative, "_If"); + } + switch_clusters[predicate_index[repr]].switches.push_back(*it); } - return predicate_switch_order; + + return switch_clusters; } StatusOr> @@ -843,10 +1026,10 @@ StatusOr< std::pair, std::unordered_set>> FunctionalizeCond::DetermineBranchMapAndFrontier( - const std::vector& switches) { + const SwitchCluster& switch_cluster) { std::unordered_map branch_map; std::unordered_set frontier; - std::vector stack = switches; + std::vector stack = switch_cluster.switches; std::vector visited(graph_->num_node_ids(), false); while (!stack.empty()) { Node* n = stack.back(); @@ -869,9 +1052,12 @@ FunctionalizeCond::DetermineBranchMapAndFrontier( ForwardFlowNode& ffn = branch_map[out]; if (IsSwitch(n)) { int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); - TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + Join(ForwardFlowNode(Branch(index)), out, &ffn), " when joining ", + e->DebugString()); } else { - TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn)); + TF_RETURN_WITH_CONTEXT_IF_ERROR(Join(branch_map[n], out, &ffn), + " when joining ", e->DebugString()); } if (IsMerge(out)) { if (out->in_edges().size() == ffn.count) { @@ -888,7 +1074,7 @@ FunctionalizeCond::DetermineBranchMapAndFrontier( } } - if (VLOG_IS_ON(2)) { + if (dump_graphs_) { for (const auto& kv : branch_map) { // Append attribute to the graph if running with logging to make the // changes clearer in the visualization. @@ -900,41 +1086,50 @@ FunctionalizeCond::DetermineBranchMapAndFrontier( } Status FunctionalizeCond::FunctionalizeInternal() { - std::vector predicate_switch_order = - DeterminePredicateSwitchOrder(); + TF_ASSIGN_OR_RETURN(std::vector predicate_switch_order, + DeterminePredicateSwitchOrder()); // Iterate from innermost set of clustered switches to outermost, replacing // matching switch->merge subgraphs with single XlaIf nodes. for (auto it = predicate_switch_order.rbegin(); it != predicate_switch_order.rend(); ++it) { auto& ps = *it; - VLOG(3) << "Flow down from: " << NodesToString(ps.switches) << " (" - << ps.predicate->name() << ")"; + VLOG(3) << "Flow down from: " << ps.ToString(); std::unordered_map branch_map; std::unordered_set frontier; TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier), - DetermineBranchMapAndFrontier(ps.switches)); + DetermineBranchMapAndFrontier(ps)); - VLOG(2) << "FunctionalizeControlFlow (before XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_bc", *graph_); + if (dump_graphs_) + LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_bc", *graph_, + library_); TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); + struct Hash { + size_t operator()(const std::pair& item) const { + return Hash64Combine(hash()(item.first), + std::hash()(item.second)); + } + }; + // Sort the merge and switch nodes using NodeCmp. The switch-nodes are // further grouped (post sorting) by input to the switch node as in the // functionalized form each input will be passed in only once. This grouping // should retain the sorted order. CondArgNodes cond_arg_nodes; - std::unordered_map input_index; std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp()); + std::unordered_map, int, Hash> input_index; for (Node* switch_node : ps.switches) { - Node* in; - TF_RETURN_IF_ERROR(switch_node->input_node(0, &in)); - if (input_index.find(in) == input_index.end()) { - input_index[in] = cond_arg_nodes.size(); - cond_arg_nodes.emplace_back(in); + const Edge* e; + TF_RETURN_IF_ERROR(switch_node->input_edge(0, &e)); + std::pair key = std::make_pair(e->src(), e->src_output()); + if (input_index.find(key) == input_index.end()) { + input_index[key] = cond_arg_nodes.size(); + cond_arg_nodes.emplace_back(key.first, key.second); } - cond_arg_nodes.at(input_index.at(in)).switch_nodes.push_back(switch_node); + cond_arg_nodes.at(input_index.at(key)).switches.push_back(switch_node); } std::vector merge_nodes(frontier.begin(), frontier.end()); std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); @@ -943,9 +1138,8 @@ Status FunctionalizeCond::FunctionalizeInternal() { EnsureDominanceAndReturnNonDominatedControlNodes( branch_map, ps.switches)); - TF_ASSIGN_OR_RETURN( - Node * if_node, - ConvertToXlaIf(cond_arg_nodes, ps.switches, merge_nodes, ps.predicate)); + TF_ASSIGN_OR_RETURN(Node * if_node, + ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes)); for (Node* old : old_control_nodes) { graph_->AddControlEdge(old, if_node); } @@ -954,25 +1148,26 @@ Status FunctionalizeCond::FunctionalizeInternal() { graph_->RemoveNode(del_kv.first); } for (auto& kv : cond_arg_nodes) { - for (Node* node : kv.switch_nodes) { + for (Node* node : kv.switches) { graph_->RemoveNode(node); } } - VLOG(2) << "FunctionalizeControlFlow (after XlaIf conversion): " - << dump_graph::DumpGraphToFile("functionalize_ac", *graph_); + if (dump_graphs_) + LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_ac", *graph_, + library_); } return Status::OK(); } StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( - const CondArgNodes& cond_arg_nodes, const std::vector& switch_nodes, - const std::vector& merge_nodes, Node* predicate) { - VLOG(2) << "Build if op for " << NodesToString(merge_nodes) << " with input " - << NodesToString(switch_nodes); + const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, + const std::vector& merge_nodes) { + VLOG(2) << "Build if op for " << switch_cluster.name; NodeDef if_def; // Create a new If node using the name of the merge node. - NodeDefBuilder builder(strings::StrCat(predicate->name(), "_If"), "XlaIf"); + NodeDefBuilder builder(switch_cluster.name, "XlaIf"); string branch[] = {"else_branch", "then_branch"}; for (int i = 0; i < 2; ++i) { static std::atomic sequence_num(0LL); @@ -982,12 +1177,9 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( body_name.set_name( strings::StrCat("_functionalize_if_", branch[i], "_", id)); auto body = xla::MakeUnique(graph_->op_registry()); - TF_RETURN_IF_ERROR( - ExtractBody(cond_arg_nodes, switch_nodes, merge_nodes, i, body.get())); + TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches, + merge_nodes, i, body.get())); VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); - VLOG(4) << "FunctionalizeControlFlow (" << branch[i] << "): " - << dump_graph::DumpGraphToFile( - strings::StrCat("functionalize_", branch[i]), *body); FunctionDef body_fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); @@ -999,7 +1191,7 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( DataTypeVector in_arg_types; for (auto& kv : cond_arg_nodes) { bool inserted = false; - for (const Node* arg : kv.switch_nodes) { + for (const Node* arg : kv.switches) { const Edge* in_edge; TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); if (in_edge->IsControlEdge()) { @@ -1026,10 +1218,12 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( builder.Attr("Tout", out_type); builder.Attr("Tcond", DT_BOOL); - builder.Device(predicate->assigned_device_name()); + builder.Device(switch_cluster.predicate_edge->src()->assigned_device_name()); // Conditional should be the first input ... - builder.Input( - NodeDefBuilder::NodeOut(predicate->name(), 0, predicate->output_type(0))); + builder.Input(NodeDefBuilder::NodeOut( + switch_cluster.predicate_edge->src()->name(), + switch_cluster.predicate_edge->src_output(), + switch_cluster.predicate_edge->src()->output_type(0))); // ... followed by the other inputs. builder.Input(inputs); @@ -1039,7 +1233,7 @@ StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( } Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, - const std::vector& switch_nodes, + const std::vector& switches, const std::vector& merge_nodes, int input_edge, Graph* body) { VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " @@ -1049,7 +1243,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, int arg_count = 0; for (auto& kv : cond_arg_nodes) { Node* arg_node = nullptr; - for (const auto* arg : kv.switch_nodes) { + for (const auto* arg : kv.switches) { DataType dtype = arg->input_type(0); if (arg_node == nullptr) { TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++)); @@ -1073,8 +1267,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, node_map.at(in->id()) = body->CopyNode(in); } - if (std::find(switch_nodes.begin(), switch_nodes.end(), in) == - switch_nodes.end()) { + if (std::find(switches.begin(), switches.end(), in) == switches.end()) { body->AddEdge(node_map.at(in->id()), in_edge->src_output(), node_map.at(node->id()), 0); } else { @@ -1090,24 +1283,17 @@ Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, } Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, - Node* predicate, Node* if_node) { + const Edge* predicate_edge, + Node* if_node) { VLOG(3) << "AddInputEdges for " << if_node->name(); int index = 0; - graph_->AddEdge(predicate, 0, if_node, index++); - for (auto& kv : cond_arg_nodes) { - bool inserted = false; - for (const Node* arg : kv.switch_nodes) { - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - graph_->AddControlEdge(in_edge->src(), if_node); - } else { - if (!inserted) { - graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, - index++); - inserted = true; - } - } + graph_->AddEdge(predicate_edge->src(), predicate_edge->src_output(), if_node, + index++); + for (auto& arg : cond_arg_nodes) { + if (arg.src_output == Graph::kControlSlot) { + graph_->AddControlEdge(arg.src, if_node); + } else { + graph_->AddEdge(arg.src, arg.src_output, if_node, index++); } } return Status::OK(); @@ -1128,10 +1314,10 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, return errors::Unimplemented("Output of index (", edge->src_output(), ") of merge node ", node->name()); } - graph_->RemoveEdge(edge); int src_output = dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; + graph_->RemoveEdge(edge); graph_->AddEdge(if_node, src_output, dst, dst_input); } } @@ -1139,16 +1325,17 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, } StatusOr FunctionalizeCond::ConvertToXlaIf( - const CondArgNodes& cond_arg_nodes, const std::vector& switch_nodes, - const std::vector& merge_nodes, Node* predicate) { - VLOG(1) << "ConvertToXlaIf for " << NodesToString(switch_nodes) << " -> " + const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, + const std::vector& merge_nodes) { + VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> " << NodesToString(merge_nodes); // Extract bodies and builds a If operator. TF_ASSIGN_OR_RETURN( Node * if_node, - BuildAndAddXlaIfOp(cond_arg_nodes, switch_nodes, merge_nodes, predicate)); - TF_RETURN_IF_ERROR(AddInputEdges(cond_arg_nodes, predicate, if_node)); + BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes)); + TF_RETURN_IF_ERROR( + AddInputEdges(cond_arg_nodes, switch_cluster.predicate_edge, if_node)); TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); return if_node; @@ -1157,18 +1344,20 @@ StatusOr FunctionalizeCond::ConvertToXlaIf( Status FunctionalizeCond::Functionalize(Graph* graph, FunctionLibraryDefinition* library) { VLOG(1) << "FunctionalizeCond::Functionalize"; - FunctionalizeCond fc(graph, library); + FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2)); return fc.FunctionalizeInternal(); } } // namespace -// Transformation that converts Tensorflow's graph control flow constructs into +// Transformation that converts TensorFlow's graph control flow constructs into // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " - << dump_graph::DumpGraphToFile("functionalize_initial", *graph); + << dump_graph::DumpGraphToFile("functionalize_initial", *graph, + library); + // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this // invariant. @@ -1180,7 +1369,8 @@ Status FunctionalizeControlFlow(Graph* graph, for (Node* node : graph->op_nodes()) { const ControlFlowInfo& cf = cf_info[node->id()]; - VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name + VLOG(2) << "node: " << node->name() << " (" << node->id() + << ") frame_name: " << cf.frame_name << " frame: " << (cf.frame ? cf.frame->name() : "---") << " parent_frame: " << (cf.parent_frame ? cf.parent_frame->name() : "---"); @@ -1248,7 +1438,8 @@ Status FunctionalizeControlFlow(Graph* graph, TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); VLOG(2) << "FunctionalizeControlFlow (final): " - << dump_graph::DumpGraphToFile("functionalize_final", *graph); + << dump_graph::DumpGraphToFile("functionalize_final", *graph, + library); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 71f12a13339b9b5495631b8f9350579f6a0785a3..bc7276c3afd5060d6faeceb4d479416299ecc5da 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -38,10 +38,11 @@ namespace { // Returns the names of the "then" and "else" functions for the XlaIf node in a // graph. -Status FindIfThenAndElse(const GraphDef& graph, NameAttrList* then_fn, - NameAttrList* else_fn) { +Status FindIfThenAndElse(const GraphDef& graph, string* op_name, + NameAttrList* then_fn, NameAttrList* else_fn) { for (const NodeDef& node : graph.node()) { if (node.op() == "XlaIf") { + *op_name = node.name(); const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); *then_fn = *result; @@ -96,9 +97,10 @@ TEST(FunctionalizeControlFlow, Conditional) { GraphDef graph_def; graph.ToGraphDef(&graph_def); + string op_name; NameAttrList then_fn; NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &then_fn, &else_fn)); + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); InstantiationResultForTest else_result; TF_EXPECT_OK( InstantiateFunctionForTest(else_fn.name(), library, &else_result)); @@ -109,7 +111,7 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op = ops::XlaIf(scope.WithOpName("cond/Less_If"), less, + auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, std::initializer_list{less, y, x}, then_fn, else_fn, {DT_INT32}); GraphDef expected; diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 1418d95956e1536292d58dfc4c2b53c53421fa94..b20c1ffc7d8956f3f5530ee63e9b711a26439be5 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -130,11 +130,11 @@ Status GraphCompiler::Compile() { // Set up inputs from outputs of previous nodes. for (auto* e : n->in_edges()) { if (e->IsControlEdge()) continue; - Node* src = e->src(); + const Node* src = e->src(); TF_RET_CHECK(src->id() < output_registry.size()); const NodeOutputs& src_outputs = output_registry[src->id()]; - tensor_inputs_[e->dst_input()] = src_outputs[e->src_output()]; + tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output()); } OpKernelContext op_context(¶ms, n->num_outputs()); diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 4c6b29bd015d048f842906cc509a6ed564629b73..d2fa933cf9c085f92b2f442827a94d72938e4bb2 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -32,6 +32,7 @@ tf_kernel_library( "dynamic_stitch_op.cc", "elu_op.cc", "extract_image_patches_op.cc", + "fake_quantize_ops.cc", "fft_ops.cc", "fill_op.cc", "function_ops.cc", @@ -64,6 +65,7 @@ tf_kernel_library( "reverse_op.cc", "reverse_sequence_op.cc", "scan_ops.cc", + "scatter_nd_op.cc", "segment_reduction_ops.cc", "select_op.cc", "sendrecv_ops.cc", @@ -96,12 +98,15 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:triangular_solve", "//tensorflow/compiler/tf2xla/lib:util", + "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 344a2ab2b6835c518c41de6f7a30fb2a34d130d2..cbade79e85eed10ecb5ead7151ee778c86a0de37 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -159,7 +159,9 @@ class BatchToSpaceNDOp : public XlaOpKernel { block_shape, crops); } }; -REGISTER_XLA_OP(Name("BatchToSpaceND").CompileTimeConstInput("crops"), +REGISTER_XLA_OP(Name("BatchToSpaceND") + .CompileTimeConstInput("block_shape") + .CompileTimeConstInput("crops"), BatchToSpaceNDOp); class BatchToSpaceOp : public XlaOpKernel { diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..453a32c494b42e9922bc35fc526f3306530054fd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -0,0 +1,289 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace { + +// Gymnastics with nudged zero point is to ensure that the real zero maps to +// an integer, which is required for e.g. zero-padding in convolutional layers. +void CpuNudge(const float min, const float max, const float quant_min, + const float quant_max, float* nudged_min, float* nudged_max, + float* scale) { + *scale = (max - min) / (quant_max - quant_min); + + const float zero_point_from_min = quant_min - min / *scale; + float nudged_zero_point; + if (zero_point_from_min <= quant_min) { + nudged_zero_point = quant_min; + } else if (zero_point_from_min >= quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = std::round(zero_point_from_min); + } + + *nudged_min = (quant_min - nudged_zero_point) * (*scale); + *nudged_max = (quant_max - nudged_zero_point) * (*scale); +} + +// An XLA version of CpuNudge(). +void XlaNudge(xla::ComputationBuilder* b, const DataType data_type, + const xla::ComputationDataHandle& min, + const xla::ComputationDataHandle& max, + const float quant_min_value, const float quant_max_value, + xla::ComputationDataHandle* nudged_min, + xla::ComputationDataHandle* nudged_max, + xla::ComputationDataHandle* scale) { + *scale = b->Div(b->Sub(max, min), + XlaHelpers::FloatLiteral(b, data_type, + quant_max_value - quant_min_value)); + xla::ComputationDataHandle quant_min = + XlaHelpers::FloatLiteral(b, data_type, quant_min_value); + xla::ComputationDataHandle zero_point_from_min = + b->Sub(quant_min, b->Div(min, *scale)); + xla::ComputationDataHandle quant_max = + XlaHelpers::FloatLiteral(b, data_type, quant_max_value); + xla::ComputationDataHandle nudged_zero_point = + b->Select(b->Le(zero_point_from_min, quant_min), quant_min, + b->Select(b->Ge(zero_point_from_min, quant_max), quant_max, + b->Round(zero_point_from_min))); + *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale); + *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale); +} + +xla::ComputationDataHandle Quantize( + xla::ComputationBuilder* b, const xla::ComputationDataHandle& input, + const DataType data_type, + const xla::ComputationDataHandle& nudged_input_min, + const xla::ComputationDataHandle& nudged_input_max, + const xla::ComputationDataHandle& input_scale) { + xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); + xla::ComputationDataHandle inv_scale = b->Div(one, input_scale); + xla::ComputationDataHandle half = + XlaHelpers::FloatLiteral(b, data_type, 0.5f); + + xla::ComputationDataHandle clamped = + b->Clamp(nudged_input_min, input, nudged_input_max); + xla::ComputationDataHandle clamped_shifted = + b->Sub(clamped, nudged_input_min); + xla::ComputationDataHandle rounded = + b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half)); + return b->Add(b->Mul(rounded, input_scale), nudged_input_min); +} + +class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + quant_min_ = narrow_range ? 1 : 0; + quant_max_ = (1 << num_bits) - 1; + + float input_min, input_max; + OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max)); + CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_, + &nudged_input_max_, &input_scale_); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle input = ctx->Input(0); + const DataType data_type = ctx->input_type(0); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); + xla::ComputationDataHandle nudged_input_max = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); + xla::ComputationDataHandle input_scale = + XlaHelpers::FloatLiteral(b, data_type, input_scale_); + xla::ComputationDataHandle output = Quantize( + b, input, data_type, nudged_input_min, nudged_input_max, input_scale); + ctx->SetOutput(0, output); + } + + private: + float quant_min_; + float quant_max_; + float nudged_input_min_; + float nudged_input_max_; + float input_scale_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp); + +class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + const float quant_min = narrow_range ? 1 : 0; + const float quant_max = (1 << num_bits) - 1; + + float input_min, input_max, scale; + OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max)); + CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_, + &nudged_input_max_, &scale); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle gradient = ctx->Input(0); + const TensorShape gradient_shape = ctx->InputShape(0); + xla::ComputationDataHandle input = ctx->Input(1); + const DataType data_type = ctx->input_type(1); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); + xla::ComputationDataHandle nudged_input_max = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); + + xla::ComputationDataHandle between_nudged_min_max = + b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); + xla::ComputationDataHandle zeroes = b->Broadcast( + XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes()); + xla::ComputationDataHandle output = + b->Select(between_nudged_min_max, gradient, zeroes); + ctx->SetOutput(0, output); + } + + private: + float nudged_input_min_; + float nudged_input_max_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"), + FakeQuantWithMinMaxArgsGradOp); + +class FakeQuantWithMinMaxVarsOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + quant_min_ = narrow_range ? 1 : 0; + quant_max_ = (1 << num_bits) - 1; + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle input = ctx->Input(0); + const DataType data_type = ctx->input_type(0); + xla::ComputationDataHandle input_min = ctx->Input(1); + xla::ComputationDataHandle input_max = ctx->Input(2); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, + &nudged_input_min, &nudged_input_max, &input_scale); + + xla::ComputationDataHandle output = Quantize( + b, input, data_type, nudged_input_min, nudged_input_max, input_scale); + ctx->SetOutput(0, output); + } + + private: + float quant_min_; + float quant_max_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp); + +class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + quant_min_ = narrow_range ? 1 : 0; + quant_max_ = (1 << num_bits) - 1; + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle gradient = ctx->Input(0); + const TensorShape gradient_shape = ctx->InputShape(0); + xla::ComputationDataHandle input = ctx->Input(1); + const DataType data_type = ctx->input_type(1); + xla::ComputationDataHandle input_min = ctx->Input(2); + xla::ComputationDataHandle input_max = ctx->Input(3); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, + &nudged_input_min, &nudged_input_max, &input_scale); + + xla::ComputationDataHandle between_nudged_min_max = + b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type); + xla::ComputationDataHandle zeroes = + b->Broadcast(zero, gradient_shape.dim_sizes()); + xla::ComputationDataHandle output0 = + b->Select(between_nudged_min_max, gradient, zeroes); + ctx->SetOutput(0, output0); + + xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min); + xla::ComputationDataHandle output1 = + b->ReduceAll(b->Select(below_min, gradient, zeroes), zero, + *ctx->GetOrCreateAdd(data_type)); + ctx->SetOutput(1, output1); + + xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max); + xla::ComputationDataHandle output2 = + b->ReduceAll(b->Select(above_max, gradient, zeroes), zero, + *ctx->GetOrCreateAdd(data_type)); + ctx->SetOutput(2, output2); + } + + private: + float quant_min_; + float quant_max_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"), + FakeQuantWithMinMaxVarsGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index e9af1e9c2fcb4922ea3570516419abd04a611a04..7945c05af40df21a798a2cff51fe7f8e935793f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -32,12 +33,12 @@ Status XlaGather(const xla::ComputationDataHandle& input, DataType dtype, DataType index_type, xla::ComputationBuilder* builder, xla::ComputationDataHandle* gather_output) { - // If the indices are N-dimensional, then the last dimension of indices should - // be of size N and correspond to the N indices. - int64 num_axes = 1; + // If the indices are N-dimensional, then the minor dimension of indices + // should be of size N and correspond to the N indices. + int64 num_index_dims = 1; if (indices_are_nd) { CHECK_GE(indices_shape.dims(), 1); - num_axes = indices_shape.dim_size(indices_shape.dims() - 1); + num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); indices_shape.RemoveLastDims(1); } @@ -46,15 +47,15 @@ Status XlaGather(const xla::ComputationDataHandle& input, // input, the output is returned with shape: // input.shape[:axis] + indices.shape + input.shape[axis+1:] - const int num_indices = indices_shape.num_elements(); + const int64 num_indices = indices_shape.num_elements(); TensorShape input_shape_pre_axis(input_shape); input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims()); TensorShape input_shape_post_axis(input_shape); - input_shape_post_axis.RemoveDimRange(0, axis + num_axes); + input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); // Each slice of the input tensor has shape: // [, 1, ..., 1, ] TensorShape slice_shape(input_shape); - for (int64 i = 0; i < num_axes; ++i) { + for (int64 i = 0; i < num_index_dims; ++i) { slice_shape.set_dim(axis + i, 1); } @@ -79,7 +80,7 @@ Status XlaGather(const xla::ComputationDataHandle& input, return Status::OK(); } - for (int64 i = 0; i < num_axes; ++i) { + for (int64 i = 0; i < num_index_dims; ++i) { if (input_shape.dim_size(axis + i) == 0) { return errors::InvalidArgument("Gather dimension ", axis + i, " is of size zero in tensor with shape ", @@ -91,57 +92,30 @@ Status XlaGather(const xla::ComputationDataHandle& input, // iteration. If there is an axis dimension, we must leave it alone. std::vector flat_indices_shape = {num_indices}; if (indices_are_nd) { - flat_indices_shape.push_back(num_axes); + flat_indices_shape.push_back(num_index_dims); } // Specify the shape of the loop-carried Tensor tuple. - xla::PrimitiveType ptype; - TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); - xla::PrimitiveType idxtype; - TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &idxtype)); - std::vector tuple_shapes( - {// The iteration counter i is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(idxtype, {}), - // The input array has shape input_shape. Loop invariant. - xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()), - // The gather indices are reshaped to flat_indices_shape. Loop invariant. - xla::ShapeUtil::MakeShape(idxtype, flat_indices_shape), - // The output array, which is updated on each loop iteration. - xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())}); - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); // Construct the initial values of the loop-carried Tensors. - auto init_i = XlaHelpers::Zero(builder, index_type); + auto flat_indices = builder->Reshape(indices, flat_indices_shape); auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), loop_out_shape.dim_sizes()); - auto flat_indices = builder->Reshape(indices, flat_indices_shape); - auto init = builder->Tuple({init_i, input, flat_indices, init_out}); - - // Construct the while loop condition (i < num_indices) - std::unique_ptr condb = - builder->CreateSubBuilder("GatherWhileCond"); - condb->Lt(condb->GetTupleElement( - condb->Parameter(0, tuple_shape, "GatherWhileTuple"), 0), - XlaHelpers::IntegerLiteral(condb.get(), index_type, num_indices)); - auto cond_status = condb->Build(); - auto cond = cond_status.ConsumeValueOrDie(); + auto init = {input, flat_indices, init_out}; // Construct the while loop body's function. The implementation of gather is: // for i in range(num_indices): // index = dynamic-slice(indices, i) // xi = dynamic-slice(input, index) // output = dynamic-update-slice(output, xi, i) - std::unique_ptr bodyb = - builder->CreateSubBuilder("GatherWhileBody"); - { - // The four loop carried values. - auto loop_tuple = bodyb->Parameter(0, tuple_shape, "GatherWhileTuple"); - auto i = bodyb->GetTupleElement(loop_tuple, 0); - auto input = bodyb->GetTupleElement(loop_tuple, 1); - auto indices = bodyb->GetTupleElement(loop_tuple, 2); - auto output = bodyb->GetTupleElement(loop_tuple, 3); - - auto zero_index = XlaHelpers::Zero(bodyb.get(), index_type); + auto body_fn = [&](xla::ComputationDataHandle i, + gtl::ArraySlice loop_vars, + xla::ComputationBuilder* bodyb) { + auto input = loop_vars[0]; + auto indices = loop_vars[1]; + auto output = loop_vars[2]; + + auto zero_index = XlaHelpers::Zero(bodyb, index_type); // Slice the i-th index from the indices array. xla::ComputationDataHandle index; @@ -150,7 +124,7 @@ Status XlaGather(const xla::ComputationDataHandle& input, // Slice out the entire nd index, if applicable. indices_offset = bodyb->Pad(indices_offset, zero_index, xla::MakeEdgePaddingConfig({{0, 1}})); - index = bodyb->DynamicSlice(indices, indices_offset, {1, num_axes}); + index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims}); index = bodyb->Collapse(index, {0, 1}); } else { index = bodyb->DynamicSlice(indices, indices_offset, {1}); @@ -174,16 +148,16 @@ Status XlaGather(const xla::ComputationDataHandle& input, // Update the output Tensor auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index); - bodyb->Tuple({bodyb->Add(i, XlaHelpers::One(bodyb.get(), index_type)), - input, indices, updated_output}); - } - auto body_status = bodyb->Build(); - auto body = body_status.ConsumeValueOrDie(); + return std::vector{input, indices, + updated_output}; + }; // Construct the While loop, extract and reshape the output. - auto gather_while = builder->While(cond, body, init); - auto result = builder->GetTupleElement(gather_while, 3); - *gather_output = builder->Reshape(result, out_shape.dim_sizes()); + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype)); + TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn, + init, "gather", builder)); + *gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes()); return Status::OK(); } @@ -250,9 +224,10 @@ class GatherNdOp : public XlaOpKernel { errors::InvalidArgument("params must be at least a vector")); OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape), errors::InvalidArgument("indices must be at least a vector")); - const int64 num_axes = indices_shape.dim_size(indices_shape.dims() - 1); + const int64 num_index_dims = + indices_shape.dim_size(indices_shape.dims() - 1); OP_REQUIRES( - context, num_axes <= params_shape.dims(), + context, num_index_dims <= params_shape.dims(), errors::InvalidArgument( "index innermost dimension length must be <= params rank; saw: ", indices_shape.dim_size(indices_shape.dims() - 1), " vs. ", diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index d2b1f7913ecc9113284827b53de8fb0e5b711322..39af662b638cb9d723118e58fcfc983633fed497 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -40,6 +40,7 @@ REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); +REGISTER_XLA_OP(Name("Snapshot"), IdentityOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8433a29c4e203cac726ee6bf7f67a863447326ed --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +// Check whether updates.shape = indices.shape[:batch_dim] + +// buffer_shape[num_index_dims:] +Status ValidateUpdateShape(const TensorShape& buffer_shape, + const TensorShape& indices_shape, + const TensorShape& updates_shape) { + if (indices_shape.dims() < 1) { + return errors::InvalidArgument( + "indices shape must have >= 1 dimension; got ", + indices_shape.DebugString()); + } + + const int64 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); + const int64 batch_dim = indices_shape.dims() - 1; + + auto shape_err = [&]() { + return errors::InvalidArgument( + "Must have updates.shape = indices.shape[:batch_dim] + ", + "buffer_shape[num_index_dims:], got updates.shape: ", + updates_shape.DebugString(), + ", indices.shape: ", indices_shape.DebugString(), + ", buffer_shape: ", buffer_shape.DebugString(), + ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim); + }; + + if (updates_shape.dims() < batch_dim) return shape_err(); + if (buffer_shape.dims() < + num_index_dims + (updates_shape.dims() - batch_dim)) { + return shape_err(); + } + if (updates_shape.dims() != + batch_dim + buffer_shape.dims() - num_index_dims) { + return shape_err(); + } + for (int d = 0; d < batch_dim; ++d) { + if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) { + return shape_err(); + } + } + for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) { + if (updates_shape.dim_size(d + batch_dim) != + buffer_shape.dim_size(d + num_index_dims)) { + return shape_err(); + } + } + return Status::OK(); +} + +class ScatterNdOp : public XlaOpKernel { + public: + explicit ScatterNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + DataType dtype = context->input_type(1); + + TensorShape indices_shape = context->InputShape(0); + TensorShape updates_shape = context->InputShape(1); + + TensorShape buffer_shape; + OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape)); + + OP_REQUIRES( + context, TensorShapeUtils::IsVectorOrHigher(buffer_shape), + errors::InvalidArgument("Output must be at least 1-D, ", + "got shape: ", buffer_shape.DebugString())); + + OP_REQUIRES( + context, + buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 && + updates_shape.num_elements() == 0), + errors::InvalidArgument( + "Indices and updates specified for empty output. indices shape: ", + indices_shape.DebugString())); + + OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape, + updates_shape)); + + xla::ComputationBuilder* builder = context->builder(); + auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype), + buffer_shape.dim_sizes()); + auto indices = context->Input(0); + auto updates = context->Input(1); + auto result = + XlaScatter(buffer, updates, indices, + /*indices_are_vectors=*/true, /*combiner=*/{}, builder); + OP_REQUIRES_OK(context, result.status()); + context->SetOutput(0, result.ValueOrDie()); + } +}; + +REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstInput("shape"), ScatterNdOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/scatter_op_helpers.h deleted file mode 100644 index a5ab7de17adb734014fe2dcbd60ae5c219c8e486..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/scatter_op_helpers.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Helper methods for XLA Scatter Ops. -#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_ -#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_ - -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/util/bcast.h" - -namespace tensorflow { - -// Adds to builder an XLA computation that performs a scatter-add of input (of -// shape input_shape) keyed on indices (of shape indices_shape). The shape -// of the Tensor returned by this is num_segments input_shape[indices.dims():] -// -static xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 num_segments, DataType dtype, - xla::ComputationBuilder* builder); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index c220edd588071ef262621784015d34cd475b2918..80d6df6c48b0141734dcee1c2a3c413926931feb 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,125 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/types.h" namespace tensorflow { - -xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 num_segments, DataType dtype, - xla::ComputationBuilder* builder) { - // Flatten data for dynamic indexing via indices_1d. - TensorShape input_shape_i(input_shape); - for (int64 d = 0; d < indices_shape.dims(); ++d) { - input_shape_i.RemoveDim(0); - } - TensorShape flat_shape({indices_shape.num_elements()}); - flat_shape.AppendShape(input_shape_i); - - // output is same as flattened input shape with dim_size(0) = num_segments. - TensorShape out_shape(flat_shape); - out_shape.set_dim(0, num_segments); - - // Slices from the input data are same shape as the input data, except dim 0. - TensorShape slice_shape(flat_shape); - slice_shape.set_dim(0, 1); - TensorShape loop_out_slice_shape(out_shape); - loop_out_slice_shape.set_dim(0, 1); - - // Construct the initial values of the loop-carried variables - // Flatten the indices into 1-D for ease of iteration. - auto indices_1d = builder->Reshape(indices, {indices_shape.num_elements()}); - // Flatten the data for ease of indexing via values in indices_1d. - auto data_flat = builder->Reshape(input, flat_shape.dim_sizes()); - - auto init_i = builder->ConstantR0(0); - auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - out_shape.dim_sizes()); - - xla::PrimitiveType ptype; - TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); - - std::vector tuple_shapes( - {// The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The flattened input data is loop invariant. - xla::ShapeUtil::MakeShape(ptype, flat_shape.dim_sizes()), - // The scatter indices tensor is loop invariant. - xla::ShapeUtil::MakeShape(xla::S32, {indices_shape.num_elements()}), - // The output data array is updated each loop iteration. - xla::ShapeUtil::MakeShape(ptype, out_shape.dim_sizes())}); - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - - auto init = builder->Tuple({init_i, data_flat, indices_1d, init_out}); - - // Construct the while loop condition (i < num_indices) - xla::ComputationBuilder condb(ctx->builder()->client(), - "ScatterAddWhileCond"); - condb.Lt(condb.GetTupleElement( - condb.Parameter(0, tuple_shape, "ScatterAddWhileTuple"), 0), - condb.ConstantR0(indices_shape.num_elements())); - auto cond_status = condb.Build(); - auto cond = cond_status.ConsumeValueOrDie(); - - // Construct the while loop body's function. The implementation of scatter is: - // for i in range(num_indices): - // index = dynamic-slice(indices, i) - // xi = dynamic-slice(input, i) - // output = dynamic-update-slice(output, xi, index) - xla::ComputationBuilder bodyb(ctx->builder()->client(), - "ScatterAddWhileBody"); - { - auto input_tuple = bodyb.Parameter(0, tuple_shape, "ScatterAddWhileTuple"); - auto i = bodyb.GetTupleElement(input_tuple, 0); - auto data = bodyb.GetTupleElement(input_tuple, 1); - auto idcs = bodyb.GetTupleElement(input_tuple, 2); - auto output = bodyb.GetTupleElement(input_tuple, 3); - - // Index into the data array at i. - auto zero = bodyb.ConstantR1({0}); - std::vector index_vals(flat_shape.dims(), zero); - index_vals[0] = bodyb.Reshape(i, {1}); - auto index = bodyb.ConcatInDim(index_vals, 0); - - auto data_slice = - bodyb.Reshape(bodyb.DynamicSlice(data, index, slice_shape.dim_sizes()), - loop_out_slice_shape.dim_sizes()); - - // Index into the output array. - std::vector out_index_vals(out_shape.dims(), - zero); - out_index_vals[0] = bodyb.DynamicSlice(idcs, bodyb.Reshape(i, {1}), {1}); - auto out_index = bodyb.ConcatInDim(out_index_vals, 0); - - // Slice the output array, update value, and update the output slice. - auto updated_output = bodyb.DynamicUpdateSlice( - output, - bodyb.Add(data_slice, - bodyb.DynamicSlice(output, out_index, - loop_out_slice_shape.dim_sizes())), - out_index); - - auto ip1 = bodyb.Add(i, bodyb.ConstantR0(1)); - bodyb.Tuple({ip1, data, idcs, updated_output}); - } - auto body_status = bodyb.Build(); - auto body = body_status.ConsumeValueOrDie(); - - auto gather_while = builder->While(cond, body, init); - return builder->GetTupleElement(gather_while, 3); -} - namespace { class UnsortedSegmentSum : public XlaOpKernel { @@ -153,10 +41,10 @@ class UnsortedSegmentSum : public XlaOpKernel { // as data with the first indices.rank dimensions are replaced // by a single dimension with size num_segments. auto data = ctx->Input(0); - auto data_shape = ctx->InputShape(0); + TensorShape data_shape = ctx->InputShape(0); auto indices = ctx->Input(1); - auto indices_shape = ctx->InputShape(1); + TensorShape indices_shape = ctx->InputShape(1); int64 num_segments; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments)); @@ -174,10 +62,21 @@ class UnsortedSegmentSum : public XlaOpKernel { d, " differs ", data_shape.dim_size(d), " vs. ", indices_shape.dim_size(d))); } - auto result = XlaComputeScatterAddDynamicSlice( - ctx, data, data_shape, indices, indices_shape, num_segments, dtype_, - ctx->builder()); - ctx->SetOutput(0, result); + xla::ComputationBuilder* builder = ctx->builder(); + TensorShape buffer_shape = data_shape; + buffer_shape.RemoveDimRange(0, indices_shape.dims()); + buffer_shape.InsertDim(0, num_segments); + auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), + buffer_shape.dim_sizes()); + + auto combiner = + [](xla::ComputationDataHandle a, xla::ComputationDataHandle b, + xla::ComputationBuilder* builder) { return builder->Add(a, b); }; + + auto result = XlaScatter(buffer, /*updates=*/data, indices, + /*indices_are_vectors=*/false, combiner, builder); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 91c169428c7a88a8d107a97445aeea999946e3e9..6204aa4e27000fddec7f5b82b2198d37956f6aba 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -77,13 +77,14 @@ class StridedSliceOp : public XlaOpKernel { for (int i = 0; i < begin.size(); ++i) { if (strides[i] > 0) { slice_begin.push_back(begin[i]); - slice_end.push_back(end[i]); + slice_end.push_back(std::max(end[i], begin[i])); slice_strides.push_back(strides[i]); } else { // Negative stride: swap begin and end, add 1 because the interval // is semi-open, and mark the dimension to be reversed. slice_begin.push_back(input_shape.dim_size(i) - begin[i] - 1); - slice_end.push_back(input_shape.dim_size(i) - end[i] - 1); + slice_end.push_back(std::max(input_shape.dim_size(i) - end[i] - 1, + input_shape.dim_size(i) - begin[i] - 1)); slice_strides.push_back(-strides[i]); dimensions_to_reverse.push_back(i); } diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 0c5ad9e5255ffc3dfcfb83335060ae833937b3ce..7cb47f908d4ff43f455f1e77c53cd3cc956579ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -60,11 +60,13 @@ XLAJIT_MAKE_UNARY( b->Add(XlaHelpers::One(b, input_type(0)), x)))); // acosh(x) = log(x + sqrt(x^2 - 1)) +// = log(x + sqrt((x+1)*(x-1))) XLAJIT_MAKE_UNARY( Acosh, - b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x), - XlaHelpers::One(b, input_type(0))), - XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + b->Log(b->Add(x, + b->Pow(b->Mul(b->Add(x, XlaHelpers::One(b, input_type(0))), + b->Sub(x, XlaHelpers::One(b, input_type(0)))), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); // asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) XLAJIT_MAKE_UNARY( diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index d184f59e01788829d0ba97092c14d36e5188e4e8..488fda74bf7b5c1d66f8d706a1be3cc1fc29a492 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -49,6 +49,25 @@ cc_library( ], ) +cc_library( + name = "scatter", + srcs = ["scatter.cc"], + hdrs = ["scatter.h"], + deps = [ + ":util", + ":while_loop", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/core:lib", + ], +) + cc_library( name = "triangular_solve", srcs = ["triangular_solve.cc"], @@ -107,6 +126,21 @@ cc_library( ], ) +cc_library( + name = "while_loop", + srcs = ["while_loop.cc"], + hdrs = ["while_loop.h"], + deps = [ + ":util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc new file mode 100644 index 0000000000000000000000000000000000000000..45699233ea8b2a75e3850098250307b95546cc28 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -0,0 +1,192 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/scatter.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +xla::StatusOr XlaScatter( + const xla::ComputationDataHandle& buffer, + const xla::ComputationDataHandle& updates, + const xla::ComputationDataHandle& indices, bool indices_are_vectors, + const std::function& combiner, + xla::ComputationBuilder* builder) { + TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_shape, + builder->GetShape(buffer)); + TF_ASSIGN_OR_RETURN(std::unique_ptr updates_shape, + builder->GetShape(updates)); + TF_ASSIGN_OR_RETURN(std::unique_ptr indices_shape, + builder->GetShape(indices)); + gtl::ArraySlice indices_dims = + xla::AsInt64Slice(indices_shape->dimensions()); + gtl::ArraySlice buffer_dims = + xla::AsInt64Slice(buffer_shape->dimensions()); + + // If the indices are N-dimensional, the minor dimension of indices contains + // the indices to update. Otherwise the indices are all scalars. + int64 num_index_dims = 1; + if (indices_are_vectors) { + TF_RET_CHECK(!indices_dims.empty()); + num_index_dims = indices_dims.back(); + if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) { + return errors::InvalidArgument( + "The size of the minor dimension of the indices (shape: ", + xla::ShapeUtil::HumanString(*indices_shape), + ") must be <= the rank of the buffer (shape: ", + xla::ShapeUtil::HumanString(*buffer_shape), ")"); + } + indices_dims.pop_back(); + } + + int64 num_indices = 1; + for (int64 dim : indices_dims) { + num_indices *= dim; + } + + // Degenerate case: nothing to update. Return the buffer unchanged. + if (num_indices == 0) { + return buffer; + } + + // If any of the indexed dimensions are zero in the buffer, the update cannot + // succeed since it updates a slice of size 1. + for (int64 i = 0; i < num_index_dims; ++i) { + if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) { + return errors::InvalidArgument( + "Scatter dimension ", i, " is of size zero in tensor with shape ", + xla::ShapeUtil::HumanString(*buffer_shape)); + } + } + + // Shape of the non-indexed dimensions of the buffer. + std::vector buffer_shape_post_axes( + buffer_dims.begin() + num_index_dims, buffer_dims.end()); + + // Flatten the major dimensions of indices and updates into a single dimension + // for ease of iteration. + std::vector flat_indices_shape({num_indices}); + if (indices_are_vectors) { + flat_indices_shape.push_back(num_index_dims); + } + + std::vector flat_updates_shape({num_indices}); + flat_updates_shape.insert(flat_updates_shape.end(), + buffer_shape_post_axes.begin(), + buffer_shape_post_axes.end()); + + // Construct the initial values of the loop-carried Tensors. + auto flat_indices = builder->Reshape(indices, flat_indices_shape); + auto flat_updates = builder->Reshape(updates, flat_updates_shape); + auto init = {flat_indices, flat_updates, buffer}; + + // Constructs the loop body. The implementation of scatter is essentially: + // for i in range(num_indices): + // index = dynamic-slice(indices, i) + // update = dynamic-slice(updates, i) + // buffer = dynamic-update-slice(buffer, update, index) + auto body_fn = [&](xla::ComputationDataHandle i, + gtl::ArraySlice loop_vars, + xla::ComputationBuilder* body_builder) { + auto indices = loop_vars[0]; + auto updates = loop_vars[1]; + auto buffer = loop_vars[2]; + + auto zero_index = body_builder->ConstantLiteral( + xla::Literal::Zero(indices_shape->element_type())); + + // Slice the i-th index from the indices array. + xla::ComputationDataHandle index; + auto indices_offset = body_builder->Reshape(i, {1}); + if (indices_are_vectors) { + indices_offset = body_builder->Pad(indices_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, 1}})); + + index = body_builder->DynamicSlice(indices, indices_offset, + {1, num_index_dims}); + index = body_builder->Collapse(index, {0, 1}); + } else { + index = body_builder->DynamicSlice(indices, indices_offset, {1}); + } + + // Discard updates with negative indices, since some users expect this. + auto index_in_range = + body_builder->ReduceAll(body_builder->Le(zero_index, index), + body_builder->ConstantR0(true), + xla::CreateScalarAndComputation(body_builder)); + + // Make the index in bounds to prevent implementation defined behavior. + index = body_builder->Max(index, zero_index); + index = body_builder->Pad( + index, zero_index, + xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); + + // Slice the i-th index from the updates array. + auto updates_offset = body_builder->Reshape(i, {1}); + updates_offset = body_builder->Pad( + updates_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); + std::vector flat_updates_slice_shape({1}); + flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), + buffer_shape_post_axes.begin(), + buffer_shape_post_axes.end()); + auto update = body_builder->DynamicSlice(updates, updates_offset, + flat_updates_slice_shape); + + // Unflatten the major (iteration) dimensions of the slice to their + // original shape. + std::vector updates_slice_shape(num_index_dims, 1); + updates_slice_shape.insert(updates_slice_shape.end(), + buffer_shape_post_axes.begin(), + buffer_shape_post_axes.end()); + update = body_builder->Reshape(update, updates_slice_shape); + + // Apply the update to the buffer. If there is a combiner, use it to merge + // the current values with the update. + auto current_value = + body_builder->DynamicSlice(buffer, index, updates_slice_shape); + if (combiner) { + update = combiner(current_value, update, body_builder); + } + // Use the current value instead of the update if the index is out of + // bounds. + update = body_builder->Select(index_in_range, update, current_value); + // Apply the update. + buffer = body_builder->DynamicUpdateSlice(buffer, update, index); + + return std::vector{indices, updates, buffer}; + }; + + TF_ASSIGN_OR_RETURN( + auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(), + body_fn, init, "scatter", builder)); + return outputs[2]; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h new file mode 100644 index 0000000000000000000000000000000000000000..41e6d3b195ebf90662c7b9b42c53fcb0133ab29e --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ + +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { + +// Builds an XLA computation that performs a scatter operation on `buffer`, +// returning an updated buffer. +// For each i0, i1, ..., sets +// buffer[indices[i0, i1, ...], ...] := updates[i0, i1, ...] +// +// If `indices_are_vectors` is false, then each index in indices is a scalar, +// and the shape of `indices` must be a prefix of the shape of updates. +// Otherwise, `indices_are_vectors`, then indices are multidimensional and the +// minor dimension of `indices` represents a vector of indices. +// +// If any indices are negative, the corresponding update is discarded. +// +// If a `combiner` is provided, updates are combined with the existing values in +// the buffer using the combiner function. Otherwise, the updates replace the +// existing values. The order of updates is implementation-defined. +xla::StatusOr XlaScatter( + const xla::ComputationDataHandle& buffer, + const xla::ComputationDataHandle& updates, + const xla::ComputationDataHandle& indices, bool indices_are_vectors, + const std::function& combiner, + xla::ComputationBuilder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 9b7492f8cf6e86498d7e2f5d42e42ea978c664d8..f579669bbd852b514e021ce71d635f8ce5e4fe4d 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -57,6 +57,61 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, } } +xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, + int64 value) { + xla::Literal literal; + switch (type) { + case xla::U8: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::U32: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::U64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::S8: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::S32: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::S64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::F32: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::F64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::C64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::PRED: + LOG(FATAL) << "pred element type is not integral"; + case xla::S16: + case xla::U16: + LOG(FATAL) << "u16/s16 literals not yet implemented"; + case xla::BF16: + literal = std::move( + *xla::Literal::CreateR0(static_cast(value))); + break; + case xla::F16: + literal = std::move( + *xla::Literal::CreateR0(static_cast(value))); + break; + case xla::TUPLE: + LOG(FATAL) << "tuple element type is not integral"; + case xla::OPAQUE: + LOG(FATAL) << "opaque element type is not integral"; + default: + LOG(FATAL) << "unhandled element type " << type; + } + return builder->ConstantLiteral(literal); +} + xla::StatusOr SliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, gtl::ArraySlice start, gtl::ArraySlice end) { diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 7f93102ee78bec60018814975a0badfeb7874aa6..51f8baaf00bd8fd25baa1a87be8cb0089dfb22b5 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -32,6 +32,11 @@ xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, xla::PrimitiveType type, double value); +// Returns a integer scalar constant of 'type' with 'value'. +// If 'type' is complex, returns a real value with zero imaginary component. +xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, int64 value); + // Performs a slice in the minor dimensions of a Tensor. xla::StatusOr SliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc new file mode 100644 index 0000000000000000000000000000000000000000..86c02ac2e65c12d3527c4022df0cc603e522ef7a --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { + +xla::StatusOr> XlaWhileLoop( + const LoopConditionFunction& condition_function, + const LoopBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder) { + int arity = initial_values.size(); + std::vector var_shapes; + var_shapes.reserve(arity); + for (const xla::ComputationDataHandle& input : initial_values) { + TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); + var_shapes.push_back(std::move(*shape)); + } + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); + + // Unpacks a tuple into its component parts. + auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity, + xla::ComputationBuilder* builder) { + std::vector elements(arity); + for (int i = 0; i < arity; ++i) { + elements[i] = builder->GetTupleElement(tuple, i); + } + return elements; + }; + + // Build the condition. + std::unique_ptr cond_builder = + builder->CreateSubBuilder(strings::StrCat(name, "_condition")); + { + auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); + + TF_ASSIGN_OR_RETURN( + auto result, + condition_function(unpack_tuple(parameter, arity, cond_builder.get()), + cond_builder.get())); + TF_RETURN_IF_ERROR(cond_builder->SetReturnValue(result)); + } + TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); + + // Build the body. + std::unique_ptr body_builder = + builder->CreateSubBuilder(strings::StrCat(name, "_body")); + { + auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); + + TF_ASSIGN_OR_RETURN( + auto result, + body_function(unpack_tuple(parameter, arity, body_builder.get()), + body_builder.get())); + + TF_RET_CHECK(result.size() == initial_values.size()); + body_builder->Tuple(result); + } + TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); + + auto outputs = builder->While(cond, body, builder->Tuple(initial_values)); + + return unpack_tuple(outputs, arity, builder); +} + +xla::StatusOr> XlaForEachIndex( + int64 num_iterations, xla::PrimitiveType num_iterations_type, + const ForEachIndexBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder) { + auto while_cond_fn = [&](gtl::ArraySlice values, + xla::ComputationBuilder* cond_builder) + -> xla::StatusOr { + return cond_builder->Lt( + values[0], + IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); + }; + auto while_body_fn = [&](gtl::ArraySlice values, + xla::ComputationBuilder* body_builder) + -> xla::StatusOr> { + xla::ComputationDataHandle iteration = values[0]; + + std::vector updated_values; + updated_values.reserve(values.size()); + updated_values.push_back(body_builder->Add( + iteration, + body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); + + values.remove_prefix(1); + TF_ASSIGN_OR_RETURN(std::vector body_outputs, + body_function(iteration, values, body_builder)); + updated_values.insert(updated_values.end(), body_outputs.begin(), + body_outputs.end()); + return updated_values; + }; + + std::vector values; + values.reserve(initial_values.size() + 1); + values.push_back( + builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); + values.insert(values.end(), initial_values.begin(), initial_values.end()); + + TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, + name, builder)); + values.erase(values.begin(), values.begin() + 1); + return values; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h new file mode 100644 index 0000000000000000000000000000000000000000..2e67a0c99b6deb65fa16ab2dec1727f5cb5fcb92 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Function that builds a loop condition. Takes as input a sequence of input +// values, and returns a boolean value representing if the condition succeeds. +typedef std::function( + gtl::ArraySlice, xla::ComputationBuilder*)> + LoopConditionFunction; + +// Function that builds a loop body. Takes as input a sequence of input values +// and returns a sequence of output values. +typedef std::function>( + gtl::ArraySlice, xla::ComputationBuilder*)> + LoopBodyFunction; + +// Helper function for building an XLA while loop, where the values carried by +// the loop are a tuple of values, e.g., (a, b, c): +// while( +// condition: (a, b, c) -> bool, +// body: (a, b, c) -> (a, b, c) +// init: (a, b, c) +// ) +// 'name' is a descriptive name for the loop. +xla::StatusOr> XlaWhileLoop( + const LoopConditionFunction& condition_function, + const LoopBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder); + +// Builds an XLA loop that repeats a computation `num_iterations` times. +// +// The body function (ForEachIndexBodyFunction) takes as input a pair of +// (current iteration number, loop-carried values), and returns an updated +// vector of the loop-carried values. +typedef std::function>( + xla::ComputationDataHandle, gtl::ArraySlice, + xla::ComputationBuilder*)> + ForEachIndexBodyFunction; + +xla::StatusOr> XlaForEachIndex( + int64 num_iterations, xla::PrimitiveType num_iterations_type, + const ForEachIndexBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index fcbd157c6191655865d5e250fdf71338780bc2a6..2c3cd658e0462368ac0b51938979b7a6815a7574 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -40,20 +40,20 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { return Status::OK(); } -Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, - Tensor* host_tensor) { +Status CopyLiteralToHostTensor(const xla::Literal& literal, + Tensor* host_tensor) { + TF_RET_CHECK(xla::ShapeUtil::IsArray(literal.shape()) && + xla::ShapeUtil::ElementsIn(literal.shape()) == + host_tensor->NumElements()); xla::PrimitiveType primitive_type; - TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(target_type, &primitive_type)); + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(host_tensor->dtype(), &primitive_type)); if (literal.shape().element_type() != primitive_type) { return errors::InvalidArgument( "Cannot convert literal of type ", xla::PrimitiveType_Name(literal.shape().element_type()), - " to tensor of type ", DataTypeString(target_type)); + " to tensor of type ", DataTypeString(host_tensor->dtype())); } - - TensorShape shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); - *host_tensor = Tensor(target_type, shape); size_t total_bytes = host_tensor->TotalBytes(); if (total_bytes > 0) { const void* src_ptr = literal.untyped_data(); @@ -63,4 +63,12 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, return Status::OK(); } +Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, + Tensor* host_tensor) { + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(literal.shape(), &shape)); + *host_tensor = Tensor(target_type, shape); + return CopyLiteralToHostTensor(literal, host_tensor); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/literal_util.h b/tensorflow/compiler/tf2xla/literal_util.h index fe08e83c2391a8b24696961cacfd909d46e49e7d..f283b0236811f8d52e8fe2982a74c11c92cd20d8 100644 --- a/tensorflow/compiler/tf2xla/literal_util.h +++ b/tensorflow/compiler/tf2xla/literal_util.h @@ -29,7 +29,8 @@ namespace tensorflow { // unsupported type. Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); -// Copies 'literal' to 'host_tensor', which is allocated of type . +// Copies 'literal' to freshly allocated 'host_tensor', which is allocated of +// type . // Fails if the literal's primitive type != // DataTypeToPrimitiveType(target_type). Note that is not // derivable from the type of , because multiple tensorflow types map @@ -38,6 +39,12 @@ Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal); Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, Tensor* host_tensor); +// Copies the contents of 'literal' to a previously allocated tensor +// 'host_tensor'. The tensor and the literal must have the same number of +// elements and the same type. +Status CopyLiteralToHostTensor(const xla::Literal& literal, + Tensor* host_tensor); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LITERAL_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/python/BUILD b/tensorflow/compiler/tf2xla/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f0a2ef0651ff6115bd201a3b1c34b3c061a22a3d --- /dev/null +++ b/tensorflow/compiler/tf2xla/python/BUILD @@ -0,0 +1,24 @@ +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [ + "//learning/tfx:__subpackages__", + "//tensorflow:internal", + ], +) + +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_py_clif_cc", +) + +tf_py_clif_cc( + name = "xla_op_registry", + srcs = ["xla_op_registry.clif"], + pyclif_deps = [ + "//tensorflow/core:framework/kernel_def_pyclif", + ], + deps = [ + "//tensorflow/compiler/tf2xla:xla_compiler", + ], +) diff --git a/tensorflow/compiler/tf2xla/python/xla_op_registry.clif b/tensorflow/compiler/tf2xla/python/xla_op_registry.clif new file mode 100644 index 0000000000000000000000000000000000000000..e1ee6cc656a314876fc1fabbebe1bee39a6b2831 --- /dev/null +++ b/tensorflow/compiler/tf2xla/python/xla_op_registry.clif @@ -0,0 +1,7 @@ +from "third_party/tensorflow/core/framework/kernel_def_pyclif.h" import * # KernelDef + +from "third_party/tensorflow/compiler/tf2xla/xla_op_registry.h": + namespace `tensorflow`: + def `XlaOpRegistry::DeviceKernels` as + device_kernels(device: str, include_compilation_only_kernels: bool) -> + list diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index c5b4ec5b15f90eb43c61cddb7bfd7640fa237a3d..5ec05c4121e059ad2b1307376766a41916fe61ae 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -109,6 +109,12 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) local_flib_runtime_ = local_pflr_->GetFLR(device_->name()); flib_runtime_ = pflr_->GetFLR(device_->name()); + + // The default variable representation shape is the identity function. + if (!options_.variable_representation_shape_fn) { + options_.variable_representation_shape_fn = + [](const TensorShape& shape, DataType type) { return shape; }; + } } XlaCompiler::~XlaCompiler() = default; @@ -153,7 +159,8 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); OptimizerOptions opts; - opts.set_do_common_subexpression_elimination(true); + opts.set_opt_level(OptimizerOptions::L0); + opts.set_do_common_subexpression_elimination(false); opts.set_do_function_inlining(true); opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); @@ -184,8 +191,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, CheckSignature(fbody->arg_types, args), "Signature check failure while compiling: ", function.name()); - std::unique_ptr graph(new Graph(options_.flib_def)); - CopyGraph(*fbody->graph, graph.get()); + std::unique_ptr graph = GetGraph(fbody); // _Arg and _Retval nodes don't exist in the stored subgraph for the function; // they are added by the function body looked up. Therefore, they don't have @@ -213,15 +219,6 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, *graph); } - // Optimize the graph before running the compiler. - OptimizerOptions opts; - opts.set_do_common_subexpression_elimination(true); - opts.set_do_function_inlining(true); - opts.set_do_constant_folding(true); - GraphOptimizer optimizer(opts); - optimizer.Optimize(flib_runtime_, flib_runtime_->env(), - /*device=*/nullptr, &graph, /*shape_map=*/nullptr); - VLOG(1) << "===================================================="; TF_RETURN_IF_ERROR( CompileGraph(options, function_id, std::move(graph), args, result)); @@ -232,8 +229,8 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, } // Computes the XLA shape for argument 'arg'. -/*static*/ Status XlaCompiler::XLAShapeForArgument( - const XlaCompiler::Argument& arg, xla::Shape* xla_shape) { +Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, + xla::Shape* xla_shape) { switch (arg.kind) { case XlaCompiler::Argument::kConstant: return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), @@ -244,8 +241,12 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, TF_RET_CHECK(arg.initialized); switch (arg.resource_kind) { - case XlaResource::kVariable: - return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + case XlaResource::kVariable: { + TensorShape representation_shape = + options_.variable_representation_shape_fn(arg.shape, arg.type); + return TensorShapeToXLAShape(arg.type, representation_shape, + xla_shape); + } case XlaResource::kTensorArray: { if (arg.tensor_array_size < 0) { return errors::InvalidArgument( @@ -319,16 +320,125 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, return Status::OK(); } +// Builds the XLA computation. +// +// `retvals` is the list of retvals produced by _Retval operators, in index +// order. `variable_map` is a map from variable ID numbers to XlaOpContext +// variable states, generated by the symbolic evaluation. +// If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. +// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. +// Sets `*resource_updates` to a description of resources whose values are +// written by the computation; the variable writes are the last +// `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a (input_index, type) pair, where `input_index` is the +// index of a resource variable argument to the computation, and `type` is the +// type of the final output. +Status BuildComputation( + const std::vector& args, + const std::vector& arg_cores, + const std::vector& retvals, + const std::vector>& resources, + bool return_updated_values_for_all_resources, + xla::ComputationBuilder* builder, xla::Computation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* resource_updates) { + std::vector elems; + elems.reserve(retvals.size()); + for (const XlaExpression& retval : retvals) { + if (!retval.has_constant_value()) { + elems.push_back(retval.handle()); + } + } + *num_nonconst_outputs = elems.size(); + + // Add return values for resources whose values have changed. + std::vector arg_resources; + arg_resources.reserve(resources.size()); + for (const auto& resource : resources) { + if (resource->arg_num() >= 0) { + arg_resources.push_back(resource.get()); + } + } + std::sort(arg_resources.begin(), arg_resources.end(), + [](const XlaResource* a, const XlaResource* b) { + return a->arg_num() < b->arg_num(); + }); + + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata retval_metadata; + retval_metadata.set_op_name("XLA_Retvals"); + builder->SetOpMetadata(retval_metadata); + + for (const XlaResource* resource : arg_resources) { + const XlaCompiler::Argument& arg = args[resource->arg_num()]; + const int core = arg_cores[resource->arg_num()]; + DCHECK_LT(resource->arg_num(), arg_cores.size()); + bool modified = + resource->value().handle() != resource->initial_value().handle(); + // TensorArray gradients were modified if their values changed or there are + // any newly created gradients. + for (const auto& grad : resource->tensor_array_gradients()) { + modified = modified || + grad.second->value().handle() != + grad.second->initial_value().handle() || + arg.tensor_array_gradients.count(grad.first) == 0; + } + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); + update.input_index = resource->arg_num(); + update.type = resource->type(); + update.shape = resource->shape(); + update.modified = modified; + for (const auto& grad : resource->tensor_array_gradients()) { + update.tensor_array_gradients_accessed.insert(grad.first); + } + + // Request that the value be returned on a specific core. + xla::ScopedShardingAssignment assign_sharding( + builder, core == -1 ? tensorflow::gtl::optional() + : xla::sharding_builder::AssignDevice(core)); + + xla::ComputationDataHandle handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + + // Since we can't change the sharding metadata of as this point, + // create a tuple/get-tuple-element combination so that sharding + // assignment will be placed on this value, which will cause the resource + // update to be returned from the same device that provided the resource. + handle = builder->GetTupleElement(builder->Tuple({handle}), 0); + + elems.push_back(handle); + } + } + + *num_computation_outputs = elems.size(); + + // Builds the XLA computation. + builder->Tuple(elems); + builder->ClearOpMetadata(); + + xla::StatusOr computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status(); + } + *computation = computation_status.ConsumeValueOrDie(); + return Status::OK(); +} + +} // namespace + // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. -Status BuildArguments(const Graph& graph, - const std::vector& args, - bool use_tuple_arg, xla::ComputationBuilder* builder, - XlaContext* context, std::vector* arg_cores, - std::vector* arg_expressions, - std::vector* input_mapping, - std::vector* input_shapes, - bool is_entry_computation) { +Status XlaCompiler::BuildArguments( + const Graph& graph, const std::vector& args, + bool use_tuple_arg, xla::ComputationBuilder* builder, XlaContext* context, + std::vector* arg_cores, std::vector* arg_expressions, + std::vector* input_mapping, std::vector* input_shapes, + bool is_entry_computation) { arg_expressions->resize(args.size()); *arg_cores = std::vector(args.size(), -1); @@ -383,8 +493,8 @@ Status BuildArguments(const Graph& graph, std::vector arg_shapes(input_mapping->size()); for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument( - args[(*input_mapping)[i]], &arg_shapes[i])); + TF_RETURN_IF_ERROR( + XLAShapeForArgument(args[(*input_mapping)[i]], &arg_shapes[i])); } if (use_tuple_arg) { @@ -413,6 +523,13 @@ Status BuildArguments(const Graph& graph, } } + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata arg_metadata; + arg_metadata.set_op_name("XLA_Args"); + builder->SetOpMetadata(arg_metadata); + // Build parameter handles for non-constant arguments. std::vector arg_handles(input_mapping->size()); if (use_tuple_arg) { @@ -451,6 +568,8 @@ Status BuildArguments(const Graph& graph, } } + builder->ClearOpMetadata(); + // Fill in the handles in non-constant arguments. VLOG(2) << "XLA computation inputs:"; for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { @@ -481,108 +600,6 @@ Status BuildArguments(const Graph& graph, return Status::OK(); } -// Builds the XLA computation. -// -// `retvals` is the list of retvals produced by _Retval operators, in index -// order. `variable_map` is a map from variable ID numbers to XlaOpContext -// variable states, generated by the symbolic evaluation. -// If `return_updated_values_for_all_resources` is true, all resources will be -// included in `resource_updates`, regardless of whether their value changed. -// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. -// Sets `*resource_updates` to a description of resources whose values are -// written by the computation; the variable writes are the last -// `resource_updates.size()` return values from the computation. Each entry in -// `resource_updates` is a (input_index, type) pair, where `input_index` is the -// index of a resource variable argument to the computation, and `type` is the -// type of the final output. -Status BuildComputation( - const std::vector& args, - const std::vector& arg_cores, - const std::vector& retvals, - const std::vector>& resources, - bool return_updated_values_for_all_resources, - xla::ComputationBuilder* builder, xla::Computation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, - std::vector* resource_updates) { - std::vector elems; - elems.reserve(retvals.size()); - for (const XlaExpression& retval : retvals) { - if (!retval.has_constant_value()) { - elems.push_back(retval.handle()); - } - } - *num_nonconst_outputs = elems.size(); - - // Add return values for resources whose values have changed. - std::vector arg_resources; - arg_resources.reserve(resources.size()); - for (const auto& resource : resources) { - if (resource->arg_num() >= 0) { - arg_resources.push_back(resource.get()); - } - } - std::sort(arg_resources.begin(), arg_resources.end(), - [](const XlaResource* a, const XlaResource* b) { - return a->arg_num() < b->arg_num(); - }); - - for (const XlaResource* resource : arg_resources) { - const XlaCompiler::Argument& arg = args[resource->arg_num()]; - const int core = arg_cores[resource->arg_num()]; - DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = - resource->value().handle() != resource->initial_value().handle(); - // TensorArray gradients were modified if their values changed or there are - // any newly created gradients. - for (const auto& grad : resource->tensor_array_gradients()) { - modified = modified || - grad.second->value().handle() != - grad.second->initial_value().handle() || - arg.tensor_array_gradients.count(grad.first) == 0; - } - if (return_updated_values_for_all_resources || modified) { - resource_updates->emplace_back(); - XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = resource->arg_num(); - update.type = resource->type(); - update.shape = resource->shape(); - update.modified = modified; - for (const auto& grad : resource->tensor_array_gradients()) { - update.tensor_array_gradients_accessed.insert(grad.first); - } - - // Request that the value be returned on a specific core. - xla::ScopedShardingAssignment assign_sharding( - builder, core == -1 ? tensorflow::gtl::optional() - : xla::sharding_builder::AssignDevice(core)); - - xla::ComputationDataHandle handle; - TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); - - // Since we can't change the sharding metadata of as this point, - // create a tuple/get-tuple-element combination so that sharding - // assignment will be placed on this value, which will cause the resource - // update to be returned from the same device that provided the resource. - handle = builder->GetTupleElement(builder->Tuple({handle}), 0); - - elems.push_back(handle); - } - } - - *num_computation_outputs = elems.size(); - - // Builds the XLA computation. - builder->Tuple(elems); - xla::StatusOr computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status(); - } - *computation = computation_status.ConsumeValueOrDie(); - return Status::OK(); -} - -} // namespace - Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, string const& name, std::unique_ptr graph, @@ -607,7 +624,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, xla::ComputationBuilder builder(client(), name); XlaContext* context = new XlaContext(this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants); + options.resolve_compile_time_constants, + &options_.variable_representation_shape_fn); core::ScopedUnref context_unref(context); std::vector arg_expressions; diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index b86c82c0ab5ce379d35a13043857f459199e2ad2..c4449bc4be06daff856eff70c6d89be6ddbcf0ee 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -29,6 +29,9 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { + +class XlaContext; + // The XlaCompiler class is responsible for compilation of a self-contained // subgraph of a TensorFlow computation using the XLA linear algebra runtime. // It does a symbolic execution of the graph starting from specific input @@ -239,6 +242,12 @@ class XlaCompiler { // for CPU. bool allow_cpu_custom_calls = false; + // If set, the XLA representation of variables represented to XLA as the + // shape given by this shape function. Variables are reshaped to this shape + // on write, and reshaped to their original shape on read. + std::function + variable_representation_shape_fn; + // If not nullptr, populate_resource_manager is called with the // compilation device's resource manager when the compilation // device is created, and can be used to create metadata objects @@ -278,7 +287,7 @@ class XlaCompiler { // Returns the shape of the XLA parameter for an argument 'arg'. // See the class comment for more details about the argument passing // convention. - static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); + Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. @@ -299,6 +308,17 @@ class XlaCompiler { // Returns the optimized graph object in this function body. std::unique_ptr GetGraph(const FunctionBody* fbody); + // Builds XLA computations for each of the arguments to the computation. + // `args` are the arguments to the computation. + Status BuildArguments(const Graph& graph, + const std::vector& args, + bool use_tuple_arg, xla::ComputationBuilder* builder, + XlaContext* context, std::vector* arg_cores, + std::vector* arg_expressions, + std::vector* input_mapping, + std::vector* input_shapes, + bool is_entry_computation); + // Graph compiler needs to know how to get an optimized graph from a function // body. friend class GraphCompiler; diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 65de4dbad75b7fb76a041bc799fc31dc5cb80d74..a18eeacd41808884fac9ec5d617cb0d274ea27d8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/ops/data_flow_ops.h" #include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/resource_variable_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -683,5 +684,128 @@ TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { << status.error_message(); } +// Tests a simple graph that reads and writes a variable. +TEST_F(XlaCompilerTest, Variables) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + auto write = ops::AssignAddVariableOp(scope, var, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. + XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + // Tests that the generated computation works. + std::unique_ptr param0_literal = + xla::Literal::CreateR1({7, 42}); + std::unique_ptr param1_literal = + xla::Literal::CreateR1({-3, 101}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::Literal::CreateR1({5, 144}); + std::unique_ptr expected1 = + xla::Literal::CreateR1({4, 143}); + std::unique_ptr expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); +} + +// Tests a simple graph that reads and writes a variable, with a +// variable_representation_shape_fn passed to the compiler that flattens all +// variable tensors to vectors. +TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1); + auto write = ops::AssignAddVariableOp(scope, var, a); + auto read = ops::ReadVariableOp( + scope.WithControlDependencies(std::vector{write}), var, + DT_INT32); + auto read_plus_one = ops::Add(scope, read, ops::Const(scope, 1)); + auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2, 2}); + args[1].kind = XlaCompiler::Argument::kResource; + args[1].resource_kind = XlaResource::kVariable; + args[1].initialized = true; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2, 2}); + + // Compiles the graph. + XlaCompiler::Options options = DefaultOptions(); + options.variable_representation_shape_fn = [](const TensorShape& shape, + DataType type) { + return TensorShape({shape.num_elements()}); + }; + XlaCompiler compiler(options); + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add", + std::move(graph), args, &result)); + + // Tests that the generated computation works. + std::unique_ptr param0_literal = + xla::Literal::CreateR2({{4, 55}, {1, -3}}); + std::unique_ptr param1_literal = + xla::Literal::CreateR1({22, 11, 33, 404}); + std::unique_ptr param0_data = + client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); + std::unique_ptr param1_data = + client_->TransferToServer(*param1_literal).ConsumeValueOrDie(); + + std::unique_ptr actual = + client_ + ->Execute(*result.computation, {param0_data.get(), param1_data.get()}) + .ConsumeValueOrDie(); + std::unique_ptr actual_literal = + client_->Transfer(*actual).ConsumeValueOrDie(); + + std::unique_ptr expected0 = + xla::Literal::CreateR2({{27, 67}, {35, 402}}); + std::unique_ptr expected1 = + xla::Literal::CreateR1({26, 66, 34, 401}); + std::unique_ptr expected_literal = + xla::Literal::MakeTuple({expected0.get(), expected1.get()}); + xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 73878955e3fd54c103c0b07faf7f5ee5bcd84de0..8423921086fec1cf534cf613102fc3839035cb85 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -62,13 +62,16 @@ void XlaContext::set_args(std::vector args) { args_ = std::move(args); } -XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, - bool allow_cpu_custom_calls, - bool resolve_compile_time_constants) +XlaContext::XlaContext( + XlaCompiler* compiler, xla::ComputationBuilder* builder, + bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + const std::function* + variable_representation_shape_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), - resolve_compile_time_constants_(resolve_compile_time_constants) {} + resolve_compile_time_constants_(resolve_compile_time_constants), + variable_representation_shape_fn_(variable_representation_shape_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } @@ -115,6 +118,11 @@ Status XlaContext::CreateResource( return Status::OK(); } +TensorShape XlaContext::VariableRepresentationShape(const TensorShape& shape, + DataType type) const { + return (*variable_representation_shape_fn_)(shape, type); +} + const xla::Computation* XlaContext::GetOrCreateMax(const DataType type) { return LookupOrCreate(type, &max_func_, [this, type] { const string type_string = DataTypeString(type); diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index fac0352ae81e24597e1045981ac47a7cd09481da..00fbaba37c542954f690b310a184cff985a05156 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -44,7 +44,9 @@ class XlaContext : public ResourceBase { // Creates a new XlaContext. XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants); + bool allow_cpu_custom_calls, bool resolve_compile_time_constants, + const std::function* + variable_representation_shape_fn); // Virtual method defined by ResourceBase. string DebugString() override; @@ -86,6 +88,11 @@ class XlaContext : public ResourceBase { return resources_; } + // Returns the XLA shape to be used to represent a variable of TF `shape` + // and `type`. + TensorShape VariableRepresentationShape(const TensorShape& shape, + DataType type) const; + // Get an XLA lambda to compute Max. This is cached in the // XlaContext since it may be used by multiple Ops. There is a // separate specialization of the computation for each DataType. @@ -133,6 +140,11 @@ class XlaContext : public ResourceBase { // Holds ownership of resources. The resources are not ordered. std::vector> resources_; + // A function that describes how variable shapes should be represented + // in XLA. Variable values will be reshaped to this shape. Must be non-null. + const std::function* + variable_representation_shape_fn_; + // Cache of prebuilt computations indexed by their type. using ComputationMap = std::map; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 77e24162676045b88dc8b62d2c6a4ecc1e738e96..f048662953e20b2a612271e2daeef6e370c4822a 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -135,58 +135,9 @@ xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::IntegerLiteral( xla::ComputationBuilder* b, DataType data_type, int64 value) { - xla::Literal literal; xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - switch (type) { - case xla::U8: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::U32: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::U64: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::S8: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::S32: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::S64: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::F32: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::F64: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::C64: - literal = std::move(*xla::Literal::CreateR0(value)); - break; - case xla::PRED: - LOG(FATAL) << "pred element type is not integral"; - case xla::S16: - case xla::U16: - LOG(FATAL) << "u16/s16 literals not yet implemented"; - case xla::BF16: - literal = std::move( - *xla::Literal::CreateR0(static_cast(value))); - break; - case xla::F16: - literal = std::move( - *xla::Literal::CreateR0(static_cast(value))); - break; - case xla::TUPLE: - LOG(FATAL) << "tuple element type is not integral"; - case xla::OPAQUE: - LOG(FATAL) << "opaque element type is not integral"; - default: - LOG(FATAL) << "unhandled element type " << type; - } - return b->ConstantLiteral(literal); + return ::tensorflow::IntegerLiteral(b, type, value); } xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index ee29158646fa96fe554d089e11d50afb47e3e300..c4bb90d58755f16672ca7c6a6738065be6330485 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -302,10 +302,19 @@ Status XlaOpKernelContext::ReadVariableInput( "Type mismatch for read of variable ", variable->name(), ". Expected ", DataTypeString(type), "; got ", DataTypeString(variable->type())); } - *value = variable->value(); if (shape) { *shape = variable->shape(); } + + XlaContext& xla_context = XlaContext::Get(context_); + TensorShape representation_shape = xla_context.VariableRepresentationShape( + variable->shape(), variable->type()); + if (representation_shape == variable->shape()) { + *value = variable->value(); + } else { + *value = + builder()->Reshape(variable->value(), variable->shape().dim_sizes()); + } return Status::OK(); } @@ -400,8 +409,8 @@ Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { return Status::OK(); } -Status XlaOpKernelContext::AssignVariable( - int input_index, DataType type, const xla::ComputationDataHandle& handle) { +Status XlaOpKernelContext::AssignVariable(int input_index, DataType type, + xla::ComputationDataHandle handle) { TF_RET_CHECK(handle.handle() != 0); const XlaExpression* expression = @@ -419,6 +428,13 @@ Status XlaOpKernelContext::AssignVariable( XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); + + XlaContext& xla_context = XlaContext::Get(context_); + TensorShape representation_shape = + xla_context.VariableRepresentationShape(shape, type); + if (shape != representation_shape) { + handle = builder()->Reshape(handle, representation_shape.dim_sizes()); + } return variable->SetValue(handle); } diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index e1fd0f55c6d2501b4813c90171630a8df567f78a..4e4b97e0cec8d16b9b5686a779b1285906765dbd 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -175,7 +175,7 @@ class XlaOpKernelContext { // variable has been initialized with a different type or with a // different shape. Status AssignVariable(int input_index, DataType type, - const xla::ComputationDataHandle& handle); + xla::ComputationDataHandle handle); // Helper routines for the OP_REQUIRES macros void CtxFailure(const Status& s); diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 0dde6a986c61bdd5b0b2e6d7a16b29ab95be98ab..bbe808595d958346bd55bf8419306bf3de4cd1d0 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -255,6 +255,8 @@ void XlaOpRegistry::RegisterCompilationKernels() { std::vector XlaOpRegistry::DeviceKernels( const string& compilation_device_name, bool include_compilation_only_kernels) { + // Ensure compilation kernels registered. + RegisterCompilationKernels(); std::vector kernels; XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 34e733bc8d80b364cec1783006eba0a5468b55ea..c7cb69215fb051b7f87c3be3b0b419b9c1b8998c 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -372,7 +372,6 @@ tf_cc_test( cc_library( name = "array2d", - srcs = ["array2d.cc"], hdrs = ["array2d.h"], visibility = ["//visibility:public"], deps = [ diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 71aa057cd3a1c273c0e851497a78f94ba37c778e..46ee4e64c9ae7ca111d9d04bedcb74ff02a42386 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -121,6 +121,23 @@ class Array { CHECK(idx == num_elements()); } + // Creates a 2D array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array(std::initializer_list> values) + : Array(ToInt64Vector({values.size(), values.begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + values_[idx] = static_cast(it2); + ++idx; + } + } + CHECK(idx == num_elements()); + } + // Creates a 3D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. Array(InitializerList3D values) @@ -138,6 +155,27 @@ class Array { CHECK(idx == num_elements()); } + // Creates a 3D array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array(std::initializer_list>> + values) + : Array(ToInt64Vector({values.size(), values.begin()->size(), + values.begin()->begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + for (const auto& it3 : it2) { + values_[idx] = static_cast(it3); + ++idx; + } + } + } + CHECK(idx == num_elements()); + } + // Creates a 4D array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. Array(InitializerList4D values) @@ -158,6 +196,31 @@ class Array { CHECK(idx == num_elements()); } + // Creates a 4D array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array(std::initializer_list< + std::initializer_list>>> + values) + : Array(ToInt64Vector({values.size(), values.begin()->size(), + values.begin()->begin()->size(), + values.begin()->begin()->begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + for (const auto& it3 : it2) { + for (const auto& it4 : it3) { + values_[idx] = static_cast(it4); + ++idx; + } + } + } + } + CHECK(idx == num_elements()); + } + Array(const Array& other) : sizes_(other.sizes_), values_(new T[num_elements()]) { std::copy(&other.values_[0], &other.values_[0] + num_elements(), @@ -185,7 +248,7 @@ class Array { // Fills the array with the sequence i*multiplier for i=0,1,... void FillWithMultiples(const T& multiplier) { for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = i * multiplier; + values_[i] = static_cast(i) * multiplier; } } diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index bb85fbee9b97fd6b9b0bf7223a9b820989dcbfa7..d30e78ecde45cfcfcfdaac6c13c9d87ab5630c57 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -25,6 +25,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/array.h" +#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -52,6 +53,14 @@ class Array2D : public Array { Array2D(std::initializer_list> values) : Array(values) {} + // Creates an array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array2D(std::initializer_list> values) + : Array(values) {} + Array2D(const Array2D& other) : Array(other) {} int64 n1() const { return this->dim(0); } @@ -86,9 +95,21 @@ class Array2D : public Array { // Returns a linspace-populated Array2D in the range [from, to] (inclusive) // with dimensions n1 x n2. -std::unique_ptr> MakeLinspaceArray2D(float from, float to, - int64 n1, int64 n2); - +template +std::unique_ptr> MakeLinspaceArray2D(double from, double to, + int64 n1, int64 n2) { + auto array = MakeUnique>(n1, n2); + int64 count = n1 * n2; + NativeT step = (count > 1) ? (to - from) / (count - 1) : 0.0f; + auto set = [&array, n1, n2](int64 index, NativeT value) { + (*array)(index / n2, index % n2) = value; + }; + for (int64 i = 0; i < count - 1; ++i) { + set(i, static_cast(from + i * step)); + } + set(count - 1, to); + return array; +} } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_ARRAY2D_H_ diff --git a/tensorflow/compiler/xla/array2d_test.cc b/tensorflow/compiler/xla/array2d_test.cc index c08e42c20ee684dfad8268aa8223440fbfad8a33..93034a719bfbd6724c007059715754677f3f1e62 100644 --- a/tensorflow/compiler/xla/array2d_test.cc +++ b/tensorflow/compiler/xla/array2d_test.cc @@ -63,6 +63,20 @@ TEST(Array2dTest, InitializerListCtor) { EXPECT_EQ(arr(1, 2), 6); } +TEST(Array2dTest, InitializerListCtorHalf) { + Array2D arr = {{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}; + + EXPECT_EQ(arr.n1(), 2); + EXPECT_EQ(arr.n2(), 3); + + EXPECT_EQ(arr(0, 0), static_cast(1)); + EXPECT_EQ(arr(0, 1), static_cast(2)); + EXPECT_EQ(arr(0, 2), static_cast(3)); + EXPECT_EQ(arr(1, 0), static_cast(4)); + EXPECT_EQ(arr(1, 1), static_cast(5)); + EXPECT_EQ(arr(1, 2), static_cast(6)); +} + TEST(Array2dTest, Accessors) { Array2D arr = {{1, 2, 3}, {4, 5, 6}}; diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index a1c5840a5f3874e27043c821ed4684da2fa6c542..e5eb235d45d160d486d1499db665ed14a8509043 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -57,6 +57,16 @@ class Array3D : public Array { values) : Array(values) {} + // Creates an array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array3D( + std::initializer_list>> + values) + : Array(values) {} + int64 n1() const { return this->dim(0); } int64 n2() const { return this->dim(1); } int64 n3() const { return this->dim(2); } diff --git a/tensorflow/compiler/xla/array3d_test.cc b/tensorflow/compiler/xla/array3d_test.cc index 6b5f4b343b2113652758bbd5ce0fc803239c1266..691ff6c03594a98a12e0fdd2151c4c2a2c9c128a 100644 --- a/tensorflow/compiler/xla/array3d_test.cc +++ b/tensorflow/compiler/xla/array3d_test.cc @@ -69,6 +69,29 @@ TEST(Array3dTest, InitializerListCtor) { EXPECT_EQ(arr(2, 3, 1), 24); } +TEST(Array3dTest, InitializerListCtorHalf) { + Array3D arr = { + {{1.0f, 2.0f}, {3.0f, 4.0f}, {5.0f, 6.0f}, {7.0f, 8.0f}}, + {{9.0f, 10.0f}, {11.0f, 12.0f}, {13.0f, 14.0f}, {15.0f, 16.0f}}, + {{17.0f, 18.0f}, {19.0f, 20.0f}, {21.0f, 22.0f}, {23.0f, 24.0f}}}; + + EXPECT_EQ(arr.n1(), 3); + EXPECT_EQ(arr.n2(), 4); + EXPECT_EQ(arr.n3(), 2); + EXPECT_EQ(arr.num_elements(), 24); + + EXPECT_EQ(arr(0, 0, 0), static_cast(1)); + EXPECT_EQ(arr(0, 0, 1), static_cast(2)); + EXPECT_EQ(arr(0, 1, 0), static_cast(3)); + EXPECT_EQ(arr(0, 3, 1), static_cast(8)); + EXPECT_EQ(arr(1, 0, 0), static_cast(9)); + EXPECT_EQ(arr(1, 1, 1), static_cast(12)); + EXPECT_EQ(arr(2, 0, 0), static_cast(17)); + EXPECT_EQ(arr(2, 1, 1), static_cast(20)); + EXPECT_EQ(arr(2, 2, 0), static_cast(21)); + EXPECT_EQ(arr(2, 3, 1), static_cast(24)); +} + TEST(Array3dTest, Fill) { Array3D fullof7(2, 3, 4, 7); for (int64 n1 = 0; n1 < fullof7.n1(); ++n1) { diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index f8b2b2afe5fed9c465c2a1f39308b7f44311b16a..cff70e54bad0116bdd08674b626b3bf99dc89e1f 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -82,6 +82,16 @@ class Array4D : public Array { values) : Array(values) {} + // Creates an array of Eigen::half from the given nested initializer list of + // float values. + template ::value && + std::is_same::value>::type> + Array4D(std::initializer_list>>> + values) + : Array(values) {} + // Numerically-named aliases for the various dimensions. This matches the // dimension names used in array3d. int64 n4() const { return this->dim(3); } diff --git a/tensorflow/compiler/xla/array4d_test.cc b/tensorflow/compiler/xla/array4d_test.cc index 3bc8148c911df0aeade364e4ac2e2ee828bacb53..927733ea1eab43feff643c35535cc6d9ea59ba5a 100644 --- a/tensorflow/compiler/xla/array4d_test.cc +++ b/tensorflow/compiler/xla/array4d_test.cc @@ -97,6 +97,36 @@ TEST(Array3dTest, InitializerListCtor) { EXPECT_EQ(arr(2, 3, 1, 0), 24); } +TEST(Array3dTest, InitializerListCtorHalf) { + Array4D arr = { + {{{1.0f}, {2.0f}}, {{3.0f}, {4.0f}}, {{5.0f}, {6.0f}}, {{7.0f}, {8.0f}}}, + {{{9.0f}, {10.0f}}, + {{11.0f}, {12.0f}}, + {{13.0f}, {14.0f}}, + {{15.0f}, {16.0f}}}, + {{{17.0f}, {18.0f}}, + {{19.0f}, {20.0f}}, + {{21.0f}, {22.0f}}, + {{23.0f}, {24.0f}}}}; + + EXPECT_EQ(arr.n1(), 3); + EXPECT_EQ(arr.n2(), 4); + EXPECT_EQ(arr.n3(), 2); + EXPECT_EQ(arr.n4(), 1); + EXPECT_EQ(arr.num_elements(), 24); + + EXPECT_EQ(arr(0, 0, 0, 0), static_cast(1)); + EXPECT_EQ(arr(0, 0, 1, 0), static_cast(2)); + EXPECT_EQ(arr(0, 1, 0, 0), static_cast(3)); + EXPECT_EQ(arr(0, 3, 1, 0), static_cast(8)); + EXPECT_EQ(arr(1, 0, 0, 0), static_cast(9)); + EXPECT_EQ(arr(1, 1, 1, 0), static_cast(12)); + EXPECT_EQ(arr(2, 0, 0, 0), static_cast(17)); + EXPECT_EQ(arr(2, 1, 1, 0), static_cast(20)); + EXPECT_EQ(arr(2, 2, 0, 0), static_cast(21)); + EXPECT_EQ(arr(2, 3, 1, 0), static_cast(24)); +} + TEST(Array4dTest, Fill) { Array4D fullof7(2, 3, 4, 5, 7); fullof7.Each([](tensorflow::gtl::ArraySlice idx, int* cell) { diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc index 8b9419477479d952126fd831eb44899e7649ca71..e8356c9832d34135f5ffb1a5c7a9d6db6db3a051 100644 --- a/tensorflow/compiler/xla/array_test.cc +++ b/tensorflow/compiler/xla/array_test.cc @@ -60,6 +60,25 @@ TEST(ArrayTest, InitializerListCtor) { EXPECT_EQ(arr(1, 2), 6); } +TEST(ArrayTest, InitializerListCtorHalf) { + Array d2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + EXPECT_EQ(d2.dim(0), 2); + EXPECT_EQ(d2.dim(1), 3); + + Array d3({{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}}); + EXPECT_EQ(d3.dim(0), 3); + EXPECT_EQ(d3.dim(1), 2); + EXPECT_EQ(d3.dim(2), 1); + + Array d4( + {{{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}}, + {{{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}, {{1.0f}, {4.0f}}}}); + EXPECT_EQ(d4.dim(0), 2); + EXPECT_EQ(d4.dim(1), 3); + EXPECT_EQ(d4.dim(2), 2); + EXPECT_EQ(d4.dim(3), 1); +} + TEST(ArrayTest, IndexingReadWrite) { Array arr({2, 3}); diff --git a/tensorflow/compiler/xla/client/computation.cc b/tensorflow/compiler/xla/client/computation.cc index 4baea8df6e3331200ee52f500fb7b961428e56be..e6c57bda0f0c4cb969939883efebcf3a6d6be381 100644 --- a/tensorflow/compiler/xla/client/computation.cc +++ b/tensorflow/compiler/xla/client/computation.cc @@ -64,4 +64,14 @@ void Computation::ResetWithoutFreeing() { parent_ = nullptr; } +StatusOr Computation::GetProgramShape() const { + GetComputationShapeRequest request; + *request.mutable_computation() = handle_; + GetComputationShapeResponse response; + + TF_RETURN_IF_ERROR(parent_->GetComputationShape(&request, &response)); + + return std::move(*response.mutable_program_shape()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/computation.h b/tensorflow/compiler/xla/client/computation.h index b595172486950bf08b057625d7b2dd97ac9b2278..a53fc9e9cf34704bd08ddb5bf062c1ec1107f5fb 100644 --- a/tensorflow/compiler/xla/client/computation.h +++ b/tensorflow/compiler/xla/client/computation.h @@ -60,6 +60,10 @@ class Computation { // Returns true if this object is a null Computation. bool IsNull() const { return parent_ == nullptr; } + // Returns the "program shape" (parameter and return shapes) for this + // computation. + StatusOr GetProgramShape() const; + private: void ResetWithoutFreeing(); diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 46f2ed4836eda6bf6d5b68f2e29ac6888cd1749b..2a6e02649d15bc9fd47a893c41f9c8a62ac076c6 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -233,6 +233,26 @@ StatusOr> ComputationBuilder::GetShape( return status_or_shape; } +StatusOr ComputationBuilder::GetProgramShape() { + TF_RETURN_IF_ERROR(first_error_); + + GetComputationShapeRequest request; + *request.mutable_computation() = computation_.handle(); + GetComputationShapeResponse response; + + VLOG(2) << "making get-program-shape-request"; + Status status = client_->stub()->GetComputationShape(&request, &response); + VLOG(2) << "done with get-program-shape-request"; + + if (!status.ok()) { + first_error_ = status; + return status; + } + + TF_RET_CHECK(response.has_program_shape()); + return std::move(*response.mutable_program_shape()); +} + ComputationDataHandle ComputationBuilder::CheckShape( const ComputationDataHandle& operand, const Shape& expected_shape) { std::unique_ptr actual_shape = GetShape(operand).ConsumeValueOrDie(); @@ -769,6 +789,20 @@ ComputationDataHandle ComputationBuilder::CustomCall( return RunOpAndParseResponse(&op_request); } +ComputationDataHandle ComputationBuilder::HostCompute( + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { + OpRequest op_request; + HostComputeRequest* request = op_request.mutable_host_compute_request(); + for (const ComputationDataHandle& operand : operands) { + *request->add_operands() = operand; + } + *request->mutable_shape() = shape; + request->set_channel_name(channel_name); + request->set_cost_estimate_ns(cost_estimate_ns); + return RunOpAndParseResponse(&op_request); +} + ComputationDataHandle ComputationBuilder::Complex( const ComputationDataHandle& real, const ComputationDataHandle& imag, tensorflow::gtl::ArraySlice broadcast_dimensions) { @@ -1200,6 +1234,22 @@ ComputationDataHandle ComputationBuilder::While( return RunOpAndParseResponse(&op_request); } +ComputationDataHandle ComputationBuilder::Gather( + const ComputationDataHandle& input, + const ComputationDataHandle& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + OpRequest op_request; + GatherRequest* gather_request = op_request.mutable_gather_request(); + *gather_request->mutable_input() = input; + *gather_request->mutable_gather_indices() = gather_indices; + *gather_request->mutable_dimension_numbers() = dimension_numbers; + for (int64 window_bound : window_bounds) { + gather_request->add_window_bounds(window_bound); + } + return RunOpAndParseResponse(&op_request); +} + ComputationDataHandle ComputationBuilder::Conditional( const ComputationDataHandle& predicate, const ComputationDataHandle& true_operand, diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index ea4cdb76673b1c99036224bcd754ce4fe1360945..377b6716399ea87b12bd0bd8a9486d4476e3cbf0 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -101,6 +101,9 @@ class ComputationBuilder { StatusOr> GetShape( const ComputationDataHandle& operand); + // Retrieves the (inferred) result for the current computation's shape. + StatusOr GetProgramShape(); + // Checks that the operand has the given expected shape. Returns the operand // if yes, fails with a CHECK error if no. ComputationDataHandle CheckShape(const ComputationDataHandle& operand, @@ -195,9 +198,8 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice new_sizes); // Enqueues an operation onto the computation that collapses the operand, from - // minor to major order, then reshapes it into the shape with the given - // dimension sizes, also from major to minor. Conceptually, this is a limited - // form of "shape casting". + // first to last dimension (C order), then reshapes it to the given dimension + // sizes. Conceptually, this is a limited form of "shape casting". ComputationDataHandle Reshape(const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice new_sizes); @@ -443,6 +445,16 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice operands, const Shape& shape); + // Enqueues a pseudo-op to represent host-side computation data-dependencies. + // During code generation, host send and receive operations will be generated + // to transfer |operands| to the host and a single result of |shape| back to + // the device. Host send/recv operations are emitted using |channel_name|. + // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO + // instruction scheduling. + ComputationDataHandle HostCompute( + tensorflow::gtl::ArraySlice operands, + const string& channel_name, int64 cost_estimate_ns, const Shape& shape); + // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one // of the operands is a scalar, or an explicit broadcast dimension is given @@ -705,6 +717,13 @@ class ComputationBuilder { const int exponent_bits, const int mantissa_bits); + // Enqueues a Gather node onto the computation. + ComputationDataHandle Gather( + const ComputationDataHandle& input, + const ComputationDataHandle& gather_indices, + const GatherDimensionNumbers& dimension_numbers, + tensorflow::gtl::ArraySlice window_bounds); + // Enqueues a Send node onto the computation, to send the given operand to // a Recv instruction that shares the same channel handle. void Send(const ComputationDataHandle& operand, const ChannelHandle& handle); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index ef98dbb6403beedb0c08ab9a0fc9e7d4ee31ab3b..91396f055fe4a3ecbd436139be9470e2a35e1c63 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -172,7 +172,9 @@ StatusOr> LocalExecutable::Run( std::unique_ptr result, executable_->ExecuteOnStreamWrapper( &service_options, run_options.execution_profile(), arguments)); - return ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator()); + + return MakeUnique(std::move(*result), + run_options.allocator()); } StatusOr> LocalExecutable::ExecuteAndDump( diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index e0a9b148b443e90a0c4f3e19660b6234d49eef84..823da43b5ab2e9c8e80181efc993735877a2c363 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -1009,6 +1009,49 @@ void Literal::SortSparseElements(const ShapeIndex& shape_index) { piece(shape_index).SortSparseElements(); } +Literal Literal::GetFirstScalarLiteral() const { + CHECK(ShapeUtil::IsArray(shape_)); + CHECK_GT(ShapeUtil::ElementsIn(shape_), 0); + switch (shape_.element_type()) { + case PRED: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 8 bit types. + case S8: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U8: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 16 bit types. + case BF16: + return std::move( + *Literal::CreateR0(GetFirstElement())); + case F16: + return std::move(*Literal::CreateR0(GetFirstElement())); + case S16: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U16: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 32 bit types. + case F32: + return std::move(*Literal::CreateR0(GetFirstElement())); + case S32: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U32: + return std::move(*Literal::CreateR0(GetFirstElement())); + // 64 bit types. + case C64: + return std::move( + *Literal::CreateR0(GetFirstElement())); + case F64: + return std::move(*Literal::CreateR0(GetFirstElement())); + case S64: + return std::move(*Literal::CreateR0(GetFirstElement())); + case U64: + return std::move(*Literal::CreateR0(GetFirstElement())); + default: + LOG(FATAL) << "Unhandled primitive type " << shape_.element_type(); + } +} + void Literal::Piece::SortSparseElements() { switch (subshape().element_type()) { case PRED: @@ -1571,6 +1614,92 @@ bool Literal::IsAllComplex(complex64 value) const { } } +bool Literal::IsAllFirst() const { + for (const auto& pair : pieces_) { + const Piece& piece = pair.second; + if (!ShapeUtil::IsArray(piece.subshape())) { + continue; + } + + // Empty shapes are not all the first element since there is no first + // element. + if (ShapeUtil::HasZeroElements(piece.subshape())) { + return false; + } + auto piece_is_all = [&]() { + switch (piece.subshape().element_type()) { + case PRED: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 8 bit types + case S8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U8: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 16 bit types + case BF16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U16: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 32 bit types + case F32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S32: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + // 64 bit types + case C64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case F64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case S64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + case U64: { + auto data = piece.data(); + return AllElementsEqualValue(data, data[0]); + } + default: + return false; + } + }; + + if (!piece_is_all()) { + return false; + } + } + return true; +} + bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index d996004888ab521790b4c5a10da2a93f0d98d12f..d5ae3fd72322fe243f0156dfbe236b6d62ab8c9d 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -451,6 +451,9 @@ class Literal { template NativeT GetFirstElement() const; + // Returns a literal scalar representing the first element. + Literal GetFirstScalarLiteral() const; + // As Get(), but determines the correct type and converts the value // into text. string GetAsString(tensorflow::gtl::ArraySlice multi_index, @@ -602,6 +605,9 @@ class Literal { // This literal must have a dense layout. bool IsAllComplex(complex64 value) const; + // Literal consists entirely of the first element of the literal. + bool IsAllFirst() const; + // Returns whether this literal is zero at the specified index. This literal // must be an array with a dense layout. bool IsZero(tensorflow::gtl::ArraySlice indices) const; diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index b3583c2eb75de8297d5e7507430491f119bd4462..ee2f4fe87440428c7364fe2924003c5124f4eaa2 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -501,6 +501,24 @@ TEST_F(LiteralUtilTest, IsAllComplex) { ->IsAllComplex({8.0f, 9.0f})); } +TEST_F(LiteralUtilTest, IsAllFirst) { + // IsAllComplex always returns false when the literal is not complex. + EXPECT_FALSE(Literal::CreateR1({false, true})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({false, false})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + EXPECT_TRUE(Literal::CreateR1({5, 5, 5, 5})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR1({1, 1, 2})->IsAllFirst()); + + complex64 c8_9 = {8, 9}; + complex64 c7_9 = {7, 9}; + EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAllFirst()); + EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}})->IsAllFirst()); +} + TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = Literal::CreateR0(0.0f); auto scalar_one = Literal::CreateR0(1.0f); diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 3b0d8377395ca2a91fb007b784773e6df9c8d6c0..b21ab3044fae7136071f50bdba6e74b799a309d5 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -278,6 +278,12 @@ const Computation& LocalComputation::computation() const { return computation_; } +StatusOr LocalComputation::GetReturnValueShape() const { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + computation_.GetProgramShape()); + return std::move(*program_shape.mutable_result()); +} + LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) : builder_(GetOrCreateLocalClient(), computation_name) {} @@ -303,6 +309,11 @@ std::unique_ptr LocalComputationBuilder::GetShape( return builder_.GetShape(operand).ConsumeValueOrDie(); } +StatusOr LocalComputationBuilder::GetReturnValueShape() { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); + return program_shape.result(); +} + ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { return builder_.Infeed(shape); } @@ -357,6 +368,12 @@ ComputationDataHandle LocalComputationBuilder::Slice( return builder_.Slice(operand, start_indices, limit_indices, strides); } +ComputationDataHandle LocalComputationBuilder::SliceInDim( + const ComputationDataHandle& operand, int64 start_index, int64 limit_index, + int64 stride, int64 dimno) { + return builder_.SliceInDim(operand, start_index, limit_index, stride, dimno); +} + ComputationDataHandle LocalComputationBuilder::DynamicSlice( const ComputationDataHandle& operand, const ComputationDataHandle& start_indices, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 4c6a504f4cd83533185cdadf60ae2c53a0d5e911..a7375c8965e9041226ffee08dab6ffafa25312af 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -102,11 +102,16 @@ class CompiledLocalComputation { class LocalComputation { public: LocalComputation(Computation computation); + StatusOr Compile( const std::vector& argument_shapes, const ExecutableBuildOptions* build_options); + const Computation& computation() const; + // Returns the return-value shape for this computation. + StatusOr GetReturnValueShape() const; + private: Computation computation_; }; @@ -133,6 +138,9 @@ class LocalComputationBuilder { std::unique_ptr GetShape(const ComputationDataHandle& operand); + // Returns the shape of the current return value for the computation. + StatusOr GetReturnValueShape(); + ComputationDataHandle Infeed(const Shape& shape); void Outfeed(const ComputationDataHandle& operand, const Shape& shape, @@ -162,6 +170,10 @@ class LocalComputationBuilder { tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides); + ComputationDataHandle SliceInDim(const ComputationDataHandle& operand, + int64 start_index, int64 limit_index, + int64 stride, int64 dimno); + ComputationDataHandle DynamicSlice( const ComputationDataHandle& operand, const ComputationDataHandle& start_indices, diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i index 114754bde4033a13e217bd6552ebffbde7c3503b..b5354131c94930b75ea66036ddb61ecd3993414f 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.i +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -201,6 +201,15 @@ tensorflow::ImportNumpy(); } } +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + %typemap(out) Status { if (!$1.ok()) { PyErr_SetString( @@ -823,6 +832,21 @@ tensorflow::ImportNumpy(); } Py_DECREF(o); + o = PyObject_GetAttrString($input, "result_shape"); + if (o == nullptr) { + return nullptr; + } + if (o != Py_None) { + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat("ExecutableBuildOptions.result_shape could not be created from Python shape value: ", statusor.status().ToString()).c_str()); + Py_DECREF(o); + return NULL; + } + build_options.set_result_layout(statusor.ValueOrDie()); + } + Py_DECREF(o); + $1 = &build_options; } } @@ -843,6 +867,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers; %unignore xla::swig::LocalComputation; %unignore xla::swig::LocalComputation::Compile; +%unignore xla::swig::LocalComputation::GetReturnValueShape; %unignore xla::swig::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; %unignore xla::swig::LocalComputationBuilder::Build; @@ -850,6 +875,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; %unignore xla::swig::LocalComputationBuilder::Parameter; %unignore xla::swig::LocalComputationBuilder::GetShape; +%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape; %unignore xla::swig::LocalComputationBuilder::Infeed; %unignore xla::swig::LocalComputationBuilder::Outfeed; %unignore xla::swig::LocalComputationBuilder::ConstantLiteral; @@ -860,6 +886,7 @@ tensorflow::ImportNumpy(); %unignore xla::swig::LocalComputationBuilder::Collapse; %unignore xla::swig::LocalComputationBuilder::CrossReplicaSum; %unignore xla::swig::LocalComputationBuilder::Slice; +%unignore xla::swig::LocalComputationBuilder::SliceInDim; %unignore xla::swig::LocalComputationBuilder::DynamicSlice; %unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; %unignore xla::swig::LocalComputationBuilder::ConcatInDim; diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index f8cee5d5665cf95b19e037658b88ffddd5efa511..90cda42f3227c80826ffbf4e5473647c2795544d 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -30,9 +30,9 @@ from tensorflow.compiler.xla import xla_data_pb2 from tensorflow.compiler.xla.python import pywrap_xla as c_api -# Most functions are snake_case for consistency with other modules, -# whereas method names of ComputationBuilder and LocalComputation are -# CamelCase for consistency with XLA. +# Most functions are snake_case for consistency with other modules, whereas +# method names of ComputationBuilder and LocalComputation are CamelCase for +# consistency with XLA. # pylint: disable=invalid-name @@ -123,24 +123,34 @@ _BINARY_OPS = [ 'Pow', ] + XLA_ELEMENT_TYPE_TO_DTYPE = { - xla_data_pb2.F32: np.dtype(np.float32), - xla_data_pb2.F64: np.dtype(np.float64), - xla_data_pb2.S32: np.dtype(np.int32), - xla_data_pb2.S64: np.dtype(np.int64), - xla_data_pb2.U32: np.dtype(np.uint32), - xla_data_pb2.U64: np.dtype(np.uint64), - xla_data_pb2.PRED: np.dtype(np.bool), + xla_data_pb2.PRED: np.dtype('bool'), + xla_data_pb2.S8: np.dtype('int8'), + xla_data_pb2.S16: np.dtype('int16'), + xla_data_pb2.S32: np.dtype('int32'), + xla_data_pb2.S64: np.dtype('int64'), + xla_data_pb2.U8: np.dtype('uint8'), + xla_data_pb2.U16: np.dtype('uint16'), + xla_data_pb2.U32: np.dtype('uint32'), + xla_data_pb2.U64: np.dtype('uint64'), + xla_data_pb2.F16: np.dtype('float16'), + xla_data_pb2.F32: np.dtype('float32'), + xla_data_pb2.F64: np.dtype('float64'), + xla_data_pb2.C64: np.dtype('complex64'), xla_data_pb2.TUPLE: np.dtype(np.object), } # Note the conversion on the key. Numpy has a known issue wherein dtype hashing # doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, # when keying by dtype in this dict, we use the string form of dtypes. -DTYPE_TO_XLA_ELEMENT_TYPE = { - str(v): k - for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items() -} +DTYPE_TO_XLA_ELEMENT_TYPE = {str(dt): et + for et, dt in XLA_ELEMENT_TYPE_TO_DTYPE.items()} + + +def dtype_to_etype(dtype): + """Convenience function for reading DTYPE_TO_XLA_ELEMENT_TYPE.""" + return DTYPE_TO_XLA_ELEMENT_TYPE[str(np.dtype(dtype))] class LocalBuffer(object): @@ -195,6 +205,12 @@ class Shape(object): self._minor_to_major = minor_to_major self._check_minor_to_major() + def __eq__(self, other): + # pylint: disable=protected-access + return (self.np_dtype == other.np_dtype and + self._dimensions == other._dimensions and + self._minor_to_major == other._minor_to_major) + def __repr__(self): return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, ' 'minor_to_major={!r})').format(self.np_dtype, self._dimensions, @@ -354,17 +370,44 @@ class LocalComputation(object): # Ensure a reference to C-based destructor for use in __del__. if is_compiled: + assert isinstance(c_local_computation, c_api.CompiledLocalComputation) self._delete = c_api.DeleteCompiledLocalComputation else: + assert isinstance(c_local_computation, c_api.LocalComputation) self._delete = c_api.DeleteLocalComputation def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): + """Compiles an un-compiled local computation. + + Local computations are the result of a "LocalComputationBuild'ing" process + -- they start in uncompiled form, and via a call to Compile() turn into a + compiled local computation. + + Raises: + ValueError: if this is already a compiled local computation. + + Arguments: + argument_shapes: parameter shapes -- they are first laid out by layout_fn + if layout_fn is provided. Otherwise, the default layout for those shapes + will be used. + compile_options: options to use for compilation, includes an optional + laid out result shape for the computation. + layout_fn: lambda that is used to lay out the argument/result shapes. + + Returns: + A newly *compiled* local computation instance. + """ if self.is_compiled: raise ValueError('Attempt to compile a compiled local XLA computation.') + if layout_fn: argument_shapes = [ shape.map_leaves(layout_fn) for shape in argument_shapes ] + result_shape = _wrap_shape(self.c_local_computation.GetReturnValueShape()) + result_shape = result_shape.map_leaves(layout_fn) + compile_options = compile_options or CompileOptions() + compile_options.result_shape = result_shape return LocalComputation( self.c_local_computation.Compile(argument_shapes, compile_options), is_compiled=True) @@ -606,6 +649,9 @@ class ComputationBuilder(object): def GetShape(self, operand): return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + def GetReturnValueShape(self): + return _wrap_shape(self._client.GetReturnValueShape()) + def GetComputationStats(self): raise NotImplementedError() @@ -620,7 +666,7 @@ class ComputationBuilder(object): representing the configuration of the padding operation. Returns: - A ComputationDataHandle representing the added pad op. + A ComputationDataHandle representing the added Pad op. """ if not isinstance(padding_config, xla_data_pb2.PaddingConfig): padding_config = GetPaddingConfigFromTriples(padding_config) @@ -630,7 +676,20 @@ class ComputationBuilder(object): padding_config)) def Reshape(self, operand, dimensions, new_sizes): - """Reshape op.""" + """Enqueues a reshape op onto the computation. + + Args: + operand: ComputationDataHandle representing the array to be reshaped. + dimensions: sequence of integers encoding the order in which dimensions + are collapsed or None, in which case dimensions are flattened in order. + new_sizes: sequence of integers encoding the new dimension sizes (shape). + + Returns: + A ComputationDataHandle representing the added Reshape op. + """ + if dimensions is None: + ndim = len(self.GetShape(operand).dimensions()) + dimensions = tuple(range(ndim)) return _wrap_data_handle( self._client.Reshape( _unwrap_data_handle(operand), dimensions, new_sizes)) @@ -736,11 +795,27 @@ class ComputationBuilder(object): strides = [1] * len(start_indices) return _wrap_data_handle( self._client.Slice( - _unwrap_data_handle(operand), - start_indices, - limit_indices, + _unwrap_data_handle(operand), start_indices, limit_indices, strides)) + def SliceInDim(self, operand, start_index, limit_index, stride, dimno): + """Enqueues a slice-in-dimension operation onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be sliced. + start_index: an integer containing the start index of the slice. + limit_index: an integer containing the end index of the slice. + stride: an integer containing the stride size for the slice. + dimno: an integer indicating the dimension along which to slice. + + Returns: + A ComputationDataHandle representing the added Slice op. + """ + return _wrap_data_handle( + self._client.SliceInDim( + _unwrap_data_handle(operand), start_index, limit_index, stride, + dimno)) + def DynamicSlice(self, operand, start_indices, slice_sizes): """Enqueues a slice op with dynamic start indices onto the computation. diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py index 65720c6ef9ec1cd7a816bcf719960fa803dd45a1..4c16c1f8b07a28d8098e92e27f81a126ed9bdf0c 100644 --- a/tensorflow/compiler/xla/python/xla_client_test.py +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -86,7 +86,8 @@ class ComputationsWithConstantsTest(LocalComputationTest): def testConstantScalarSumF32(self): c = self._NewComputation() - c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) self._ExecuteAndCompareClose(c, expected=4.25) def testConstantScalarSumF64(self): @@ -761,6 +762,23 @@ class SingleOpTest(LocalComputationTest): [3, 2]) self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + def testSliceInDim(self): + c = self._NewComputation() + c.SliceInDim( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=1, + limit_index=2, + stride=1, + dimno=1) + self._ExecuteAndCompareExact(c, expected=[[2], [5], [8]]) + c.SliceInDim( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + start_index=0, + limit_index=3, + stride=2, + dimno=0) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [7, 8, 9]]) + def testDynamicSlice(self): c = self._NewComputation() c.DynamicSlice( @@ -881,6 +899,13 @@ class EmbeddedComputationsTest(LocalComputationTest): c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0)) return c.Build() + def _CreateMulF32ByParamComputation(self): + """Computation (f32) -> f32 that multiplies one parameter by the other.""" + c = self._NewComputation("mul_f32_by_param") + c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + def _CreateMulF64By2Computation(self): """Computation (f64) -> f64 that multiplies its parameter by 2.""" c = self._NewComputation("mul_f64_by2") @@ -1021,6 +1046,14 @@ class EmbeddedComputationsTest(LocalComputationTest): self._CreateBinaryDivF64Computation(), [0]) self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + def DISABLED_testMapWithStaticOperands(self): + c = self._NewComputation() + factor = c.ConstantF32Scalar(3.0) + c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF32ByParamComputation(), [0], + static_operands=[factor]) + self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0]) + def testSelectAndScatterF32(self): c = self._NewComputation() c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 0f2d0a9e96e20007aa24a22832bdca4f0add372d..e6a6e54927b4752f6e7c8eca1fc0e84301ff0a58 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -43,6 +43,115 @@ filegroup( ]), ) +cc_library( + name = "bfloat16_support", + srcs = ["bfloat16_support.cc"], + hdrs = ["bfloat16_support.h"], + deps = [ + ":hlo", + ], +) + +cc_library( + name = "bfloat16_conversion_folding", + srcs = ["bfloat16_conversion_folding.cc"], + hdrs = ["bfloat16_conversion_folding.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_conversion_folding_test", + srcs = ["bfloat16_conversion_folding_test.cc"], + deps = [ + ":bfloat16_conversion_folding", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "bfloat16_normalization", + srcs = ["bfloat16_normalization.cc"], + hdrs = ["bfloat16_normalization.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_normalization_test", + srcs = ["bfloat16_normalization_test.cc"], + deps = [ + ":bfloat16_normalization", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "bfloat16_propagation", + srcs = ["bfloat16_propagation.cc"], + hdrs = ["bfloat16_propagation.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_dataflow_analysis", + ":hlo_dce", + ":hlo_pass", + ":tuple_simplifier", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_propagation_test", + srcs = ["bfloat16_propagation_test.cc"], + deps = [ + ":bfloat16_propagation", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + ], +) + cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], @@ -70,7 +179,8 @@ tf_cc_test( "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", ], ) @@ -643,6 +753,7 @@ cc_library( hdrs = ["llvm_compiler.h"], deps = [ ":compiler", + "//tensorflow/core:lib_internal", "@llvm//:core", ], ) diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index fb857559f972a220a19b108baa4c441e09b90e1f..5ddd8ec377690bdf47e6d54ae5d419416044a53c 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -122,6 +122,8 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleBitcastConvert(HloInstruction* bitcast) override; + Status HandleBroadcast(HloInstruction* broadcast) override; Status HandleConcatenate(HloInstruction* concatenate) override; @@ -411,6 +413,13 @@ Status AlgebraicSimplifierVisitor::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status AlgebraicSimplifierVisitor::HandleBitcastConvert( + HloInstruction* bitcast) { + // Eliminate bitcast converts between same shape. + ReplaceInstructionIfSameShape(bitcast, bitcast->mutable_operand(0)); + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { // If a copy feeds a copy, make it a single copy. if (copy->operand(0)->opcode() == HloOpcode::kCopy) { @@ -516,6 +525,18 @@ Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { return ReplaceInstruction( constant, BuildTupleConstant(computation_, constant->literal())); } + + // If a literal is all the same element replace it with a scalar broadcast. + if (ShapeUtil::ElementsIn(constant->shape()) > 1 && + constant->literal().IsAllFirst()) { + std::unique_ptr unique_scalar = + MakeUnique(constant->literal().GetFirstScalarLiteral()); + HloInstruction* scalar = computation_->AddInstruction( + HloInstruction::CreateConstant(std::move(unique_scalar))); + return ReplaceWithNewInstruction( + constant, + HloInstruction::CreateBroadcast(constant->shape(), scalar, {})); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 0f08eb3a3267c4b7b04958270a5788fc48d3fa04..667ae01993ebf0feeab89e0b5afaf7c7c8c99ab9 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -162,6 +162,37 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { EXPECT_EQ(root, param0); } +TEST_F(AlgebraicSimplifierTest, ConstantToBroadcast) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({3.14f, 3.14f, 3.14f}))); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Broadcast(op::Constant())); + EXPECT_EQ(3.14f, root->operand(0)->literal().GetFirstElement()); +} + +TEST_F(AlgebraicSimplifierTest, ConstantNotToBroadcast) { + HloComputation::Builder builder(TestName()); + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({3.14, 3.14, 4}))); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Constant()); +} + // Test that A - 0 is simplified to A TEST_F(AlgebraicSimplifierTest, SubZero) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc index 27ddfd47aa3096afd3e245af1ac3cedd9b48ce4a..84c9db32932becd9b701929b392efa4998d03067 100644 --- a/tensorflow/compiler/xla/service/batchnorm_expander.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -153,6 +153,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormTraining( std::vector added_instructions; auto add = [&](std::unique_ptr inst) { HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_inst->set_metadata(batch_norm->metadata()); added_instructions.push_back(added_inst); return added_inst; }; @@ -334,6 +335,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormInference( std::vector added_instructions; auto add = [&](std::unique_ptr inst) { HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_inst->set_metadata(batch_norm->metadata()); added_instructions.push_back(added_inst); return added_inst; }; @@ -419,6 +421,7 @@ Status BatchNormExpanderVisitor::HandleBatchNormGrad( std::vector added_instructions; auto add = [&](std::unique_ptr inst) { HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_inst->set_metadata(batch_norm->metadata()); added_instructions.push_back(added_inst); return added_inst; }; diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc new file mode 100644 index 0000000000000000000000000000000000000000..cde990e176ddb57a8e93ecc3c60260b2dbae32a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -0,0 +1,184 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { + public: + explicit BFloat16ConversionFoldingVisitor( + HloComputation* computation, const BFloat16Support* bfloat16_support) + : computation_(computation), bfloat16_support_(bfloat16_support) {} + + Status DefaultAction(HloInstruction* hlo) override; + + static bool Run(HloComputation* computation, + const BFloat16Support* bfloat16_support) { + BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; + } + + private: + // Checks if the HLO has a BF16 -> F32 conversion as input, or a F32 -> BF16 + // conversion as output, and folds them to the HLO itself if feasible. + Status TryFoldBF16Conversions(HloInstruction* hlo); + + // Folds the F32 -> BF16 conversions from the HLO's output. + // + // Precondition: all of the HLO's users are F32 -> BF16 conversions. + Status FoldOutputConversions(HloInstruction* hlo); + + // Folds the BF16 -> F32 conversion operand to the HLO. + // + // Precondition: the operand is a F32 -> BF16 conversion. + Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index); + + HloComputation* computation_; + const BFloat16Support* bfloat16_support_; + bool changed_ = false; +}; + +Status BFloat16ConversionFoldingVisitor::FoldOutputConversions( + HloInstruction* hlo) { + std::vector materialized_users = hlo->users(); + hlo->mutable_shape()->set_element_type(BF16); + for (auto user : materialized_users) { + CHECK_EQ(user->opcode(), HloOpcode::kConvert); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); + changed_ = true; + } + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::FoldOperandConversion( + HloInstruction* hlo, int64 operand_index) { + // The operand is a convert from BF16 to F32. + auto operand = hlo->mutable_operand(operand_index); + CHECK_EQ(operand->opcode(), HloOpcode::kConvert); + TF_RETURN_IF_ERROR( + hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0))); + changed_ = true; + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( + HloInstruction* hlo) { + std::vector bf16_to_f32_operands; + bool has_other_f32_operands = false; + for (int64 i = 0; i < hlo->operands().size(); ++i) { + auto operand = hlo->operand(i); + if (operand->shape().element_type() == F32) { + if (operand->opcode() == HloOpcode::kConvert && + operand->operand(0)->shape().element_type() == BF16 && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + // Operand is a convert from BF16 to F32 and we support BF16 input + // directly in the current HLO at the operand index. + bf16_to_f32_operands.push_back(i); + } else { + has_other_f32_operands = true; + } + continue; + } + } + + bool fold_output_conversion = hlo->user_count() > 0 && + hlo->shape().element_type() == F32 && + bfloat16_support_->SupportsBF16Output(*hlo) && + hlo != computation_->root_instruction(); + if (fold_output_conversion) { + for (auto user : hlo->users()) { + if (user->opcode() == HloOpcode::kConvert && + user->shape().element_type() == BF16) { + continue; + } + // We should not change the output type if any user is not a conversion + // from F32 to BF16. + fold_output_conversion = false; + break; + } + } + + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) { + if (has_other_f32_operands || + (!fold_output_conversion && hlo->shape().element_type() == F32)) { + // Some of the operands/output will remain F32, but we cannot use mixed + // precisions, so we cannot do anything here. + return Status::OK(); + } + } + + if (fold_output_conversion) { + TF_RETURN_IF_ERROR(FoldOutputConversions(hlo)); + } + + for (int64 i : bf16_to_f32_operands) { + TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i)); + } + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { + // Do not fold BF16 conversions for instructions related to tuples, entry and + // exit of a computation, fusion, convert, and control flow. + if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional) { + return Status::OK(); + } + if (hlo == computation_->root_instruction() && + !bfloat16_support_->SupportsMixedPrecisions(*hlo)) { + // If hlo is the root instruction, we cannot change its output, so folding + // can only happen when it supports mixed precision so that we can change + // its operands. + return Status::OK(); + } + return TryFoldBF16Conversions(hlo); +} + +StatusOr BFloat16ConversionFolding::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) { + changed = true; + } + } + XLA_VLOG_LINES( + 2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h new file mode 100644 index 0000000000000000000000000000000000000000..c9398387098fad84ba28735c30e426fedd9b0cb0 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which folds F32 <-> BF16 conversions to their operands or users, when +// it is supported by the backend. +// +// This pass follows the passed-in backend-specific BF16 support rules, but can +// introduce mixed precision in individual HLOs which breaks the assumption of +// some other HLO passes. So it should be used at the end of the HLO +// optimization pipeline followed by a DCE pass. If other passes are needed +// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the +// changed made by this pass. +class BFloat16ConversionFolding : public HloPassInterface { + public: + explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + + ~BFloat16ConversionFolding() override = default; + tensorflow::StringPiece name() const override { return "bfloat16-fold"; } + + // Run BF16 conversion folding on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + private: + const BFloat16Support* bfloat16_support_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb37759439debf41a305ec7dccaa548e1bf234cd --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -0,0 +1,209 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } +}; + +class BFloat16ConversionFoldingTest : public HloTestBase { + protected: + bool FoldConversions(HloModule* module) { + TestBFloat16Support bfloat16_support_; + BFloat16ConversionFolding fold(&bfloat16_support_); + StatusOr result = fold.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } +}; + +TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c)); + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), add1); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), BF16); + EXPECT_EQ(add1->operand(0), add0); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* mul0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_shape, HloOpcode::kMultiply, convert1, c)); + HloInstruction* convert2 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert2); + EXPECT_EQ(mul0->shape().element_type(), F32); + EXPECT_EQ(mul1->shape().element_type(), F32); + EXPECT_EQ(mul1->operand(0), convert1); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* sub0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_shape, HloOpcode::kSubtract, convert1, c)); + HloInstruction* convert2 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert2); + EXPECT_EQ(sub0->shape().element_type(), F32); + EXPECT_EQ(sub1->shape().element_type(), F32); + EXPECT_EQ(sub1->operand(0), convert1); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b)); + + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({a, convert0})); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0)); + HloInstruction* convert1 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert1); + EXPECT_EQ(gte->shape().element_type(), F32); + EXPECT_EQ(tuple->operand(1), convert0); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc new file mode 100644 index 0000000000000000000000000000000000000000..b032c040e8aff49f9e0fc1ff9a1c1e79ea4bb77f --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -0,0 +1,351 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { + public: + explicit BFloat16NormalizationVisitor(HloComputation* computation, + const BFloat16Support* bfloat16_support) + : computation_(computation), bfloat16_support_(bfloat16_support) {} + + Status DefaultAction(HloInstruction* hlo) override; + + // Special handling for cross-replica-sum which can have a tuple output. + Status HandleCrossReplicaSum(HloInstruction* crs) override; + + static bool Run(HloComputation* computation, + const BFloat16Support* bfloat16_support) { + BFloat16NormalizationVisitor visitor(computation, bfloat16_support); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; + } + + private: + // Checks if the HLO uses BF16 in an unsupported way, and if so, inserts + // conversions between F32 and BF16 to make it supported. + Status HandleInstruction(HloInstruction* hlo); + + // Inserts a conversion HLO that changes the given HLO's output type. + Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to, + HloComputation* computation); + + // Changes the output type to the specified type, then inserts a conversion + // to the original type. + Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo, + PrimitiveType to, + HloComputation* computation); + + // Inserts a conversion HLO that changes the given HLO's operand type. + Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx, + PrimitiveType to, + HloComputation* computation); + + // Inserts conversion HLOs to replace the called computations' BF16 + // operands/outputs to F32. + Status ConvertCalledComputations( + HloInstruction* hlo, + tensorflow::gtl::ArraySlice bf16_called_comps); + + HloComputation* computation_; + const BFloat16Support* bfloat16_support_; + bool changed_ = false; +}; + +Status BFloat16NormalizationVisitor::InsertConvertAfterOutput( + HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { + bool is_root = computation->root_instruction() == hlo; + std::vector materialized_users = hlo->users(); + // Use inst's shape temporarily, in order to pass checks in ReplaceUseWith. + auto convert = computation->AddInstruction( + HloInstruction::CreateConvert(hlo->shape(), hlo)); + for (auto* user : materialized_users) { + TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert)); + } + if (is_root) { + computation->set_root_instruction(convert); + } + convert->mutable_shape()->set_element_type(to); + changed_ = true; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( + HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { + auto original_type = hlo->shape().element_type(); + hlo->mutable_shape()->set_element_type(to); + return InsertConvertAfterOutput(hlo, original_type, computation); +} + +Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( + HloInstruction* hlo, int64 operand_idx, PrimitiveType to, + HloComputation* computation) { + auto operand = hlo->mutable_operand(operand_idx); + auto convert = computation->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(operand->shape(), to), operand)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert)); + changed_ = true; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::ConvertCalledComputations( + HloInstruction* hlo, + tensorflow::gtl::ArraySlice bf16_called_comps) { + std::map cloned_computations; + for (auto& comp : bf16_called_comps) { + auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone()); + cloned_computations[comp] = cloned; + changed_ = true; + } + hlo->ReplaceCalledComputations([&](HloComputation* comp) { + auto it = cloned_computations.find(comp); + if (it != cloned_computations.end()) { + return it->second; + } + return comp; + }); + for (auto& comp_pair : cloned_computations) { + auto comp = comp_pair.second; + if (comp->root_instruction()->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + InsertConvertAfterOutput(comp->root_instruction(), F32, comp)); + } + for (auto* param : comp->parameter_instructions()) { + if (param->shape().element_type() == BF16) { + // This changes the parameter to F32 then inserts a convert after it. + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(param, F32, comp)); + } + } + } + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( + HloInstruction* crs) { + if (!ShapeUtil::IsTuple(crs->shape())) { + return HandleInstruction(crs); + } + + std::vector operand_types(crs->operand_count()); + std::vector output_types(crs->operand_count()); + bool has_f32 = false; + bool has_bf16 = false; + bool has_bf16_output = false; + for (int64 i = 0; i < crs->operand_count(); ++i) { + operand_types[i] = crs->operand(i)->shape().element_type(); + output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); + if (operand_types[i] == F32 || output_types[i] == F32) { + has_f32 = true; + } else if (operand_types[i] == BF16) { + has_bf16 = true; + } + if (output_types[i] == BF16) { + has_bf16 = true; + has_bf16_output = true; + } + } + + for (int64 i = 0; i < crs->operand_count(); ++i) { + if (operand_types[i] != BF16) { + continue; + } + if (bfloat16_support_->SupportsBF16Operand(*crs, i) && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + continue; + } + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + has_f32 = true; + } + + if (!has_bf16_output) { + return Status::OK(); + } + + if (bfloat16_support_->SupportsBF16Output(*crs) && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + return Status::OK(); + } + + std::vector output_elements(crs->operand_count()); + auto original_shape = crs->shape(); + for (int64 i = 0; i < crs->operand_count(); ++i) { + auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}); + if (output_types[i] != BF16) { + output_elements[i] = computation_->AddInstruction( + HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + continue; + } + subshape->set_element_type(F32); + auto gte = computation_->AddInstruction( + HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + output_elements[i] = + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(*subshape, BF16), gte)); + } + auto tuple = computation_->AddInstruction( + HloInstruction::CreateTuple(output_elements)); + + std::vector materialized_users = crs->users(); + // Use the crs' shape temporarily, in order to pass checks in + // ReplaceUseWith. + *tuple->mutable_shape() = crs->shape(); + for (auto* user : materialized_users) { + TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple)); + } + *tuple->mutable_shape() = original_shape; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { + std::vector bf16_operands; + std::vector f32_operands; + bool has_f32 = false; + bool has_bf16 = false; + + for (int64 i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == F32) { + f32_operands.push_back(i); + has_f32 = true; + } else if (hlo->operand(i)->shape().element_type() == BF16) { + bf16_operands.push_back(i); + has_bf16 = true; + } + } + + if (hlo->shape().element_type() == F32) { + has_f32 = true; + } else if (hlo->shape().element_type() == BF16) { + has_bf16 = true; + } + + std::vector bf16_called_comps; + for (auto* comp : hlo->called_computations()) { + bool comp_has_bf16 = false; + if (comp->root_instruction()->shape().element_type() == F32) { + has_f32 = true; + } else if (comp->root_instruction()->shape().element_type() == BF16) { + has_bf16 = true; + comp_has_bf16 = true; + } + for (auto* param : comp->parameter_instructions()) { + if (param->shape().element_type() == F32) { + has_f32 = true; + } else if (param->shape().element_type() == BF16) { + has_bf16 = true; + comp_has_bf16 = true; + } + } + if (comp_has_bf16) { + bf16_called_comps.push_back(comp); + } + } + + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 && + has_f32) { + // Resolve unsupported mixed precision. + // + // See if we can change everything to BF16. + if (hlo->called_computations().empty() && + hlo->shape().element_type() == BF16) { + bool can_use_bf16 = true; + for (int i : f32_operands) { + if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, + i) && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + continue; + } + can_use_bf16 = false; + break; + } + if (can_use_bf16) { + for (int i : f32_operands) { + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, BF16, computation_)); + } + return Status::OK(); + } + } + if (hlo->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + for (int i : bf16_operands) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + return ConvertCalledComputations(hlo, bf16_called_comps); + } + + for (int i : bf16_operands) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + } + + if (hlo->shape().element_type() == BF16 && + !bfloat16_support_->SupportsBF16Output(*hlo)) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { + // Do not change instructions related to entry and exit of a computation, + // tuples, fusion, convert, and control flow. + if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional) { + return Status::OK(); + } + return HandleInstruction(hlo); +} + +StatusOr BFloat16Normalization::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "BFloat16Normalization::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeComputationPostOrder()) { + if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) { + changed = true; + } + } + XLA_VLOG_LINES(2, + "BFloat16Normalization::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h new file mode 100644 index 0000000000000000000000000000000000000000..2a60fe0af3218484acb95e6c69815d551350764c --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not +// support BF16 input/output or mixed precision, according to the passed-in +// backend-specific BF16 support rules. +class BFloat16Normalization : public HloPassInterface { + public: + explicit BFloat16Normalization(const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + + ~BFloat16Normalization() override = default; + tensorflow::StringPiece name() const override { return "bf16-normalization"; } + + // Run BF16 normalization on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + private: + const BFloat16Support* bfloat16_support_; +}; + +// A pass that unconditionally removes the mixed F32/BF16 uses in HLO +// instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike +// BFloat16Normalization, this pass does not use a backend-specific +// BFloat16Support, and does not change HLOs that have BF16 data if they do not +// use mixed precision; it removes mixed precision even if the backend supports +// it. This pass is used to make the HLO module valid for other HLO passes which +// do not support mixed precision. +class BFloat16MixedPrecisionRemoval : public HloPassInterface { + public: + BFloat16MixedPrecisionRemoval() {} + + ~BFloat16MixedPrecisionRemoval() override = default; + + tensorflow::StringPiece name() const override { + return "bf16-mixed-precision-removal"; + } + + // Run mixed precision removal on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override { + BFloat16Normalization normalization(&no_mixed_precision_support_); + return normalization.Run(module); + } + + private: + class BFloat16SupportForMixedPrecisionRemoval : public BFloat16Support { + public: + BFloat16SupportForMixedPrecisionRemoval() {} + + ~BFloat16SupportForMixedPrecisionRemoval() override = default; + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + return true; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + return true; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + return false; + } + } no_mixed_precision_support_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..66c3085842c4afe7ffc4d5891883e4cce9389d45 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -0,0 +1,248 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kReduce || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } +}; + +class BFloat16NormalizationTest : public HloTestBase { + protected: + bool Normalize(HloModule* module) { + TestBFloat16Support bfloat16_support_; + BFloat16Normalization normalization(&bfloat16_support_); + StatusOr result = normalization.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } +}; + +TEST_F(BFloat16NormalizationTest, NoopIfSupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b)); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), add1); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), F32); +} + +TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* mul0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b)); + + HloInstruction* mul1 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(computation->root_instruction()->operand(0), mul1); + EXPECT_EQ(mul0->shape().element_type(), F32); + EXPECT_EQ(mul1->shape().element_type(), F32); + EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert); +} + +TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* sub0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b)); + + HloInstruction* sub1 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(computation->root_instruction()->operand(0), sub1); + EXPECT_EQ(sub0->shape().element_type(), F32); + EXPECT_EQ(sub1->shape().element_type(), F32); + EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert); +} + +TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { + Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4}); + + Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + auto reduce_comp_builder = HloComputation::Builder("reduce_comp"); + auto reduce_comp_param0 = reduce_comp_builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0")); + auto reduce_comp_param1 = reduce_comp_builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1")); + reduce_comp_builder.AddInstruction( + HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd, + reduce_comp_param0, reduce_comp_param1)); + + auto module = CreateNewModule(); + auto reduce_computation = + module->AddEmbeddedComputation(reduce_comp_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_input_shape, "a")); + HloInstruction* init = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_scalar_shape, "init")); + HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce( + f32_output_shape, input, init, {0}, reduce_computation)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), reduce); + EXPECT_EQ(reduce->called_computations().size(), 1); + EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2); + EXPECT_EQ(reduce->called_computations()[0] + ->parameter_instruction(0) + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->called_computations()[0] + ->parameter_instruction(1) + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->called_computations()[0] + ->root_instruction() + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->shape().element_type(), F32); + EXPECT_EQ(reduce->operand(0), input); + EXPECT_EQ(input->shape().element_type(), F32); + EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert); + EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32); +} + +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + + HloInstruction* crs = + builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b})); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), gte); + EXPECT_EQ(gte->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(1)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc new file mode 100644 index 0000000000000000000000000000000000000000..6145c690b911dd3c74d2677ceb840ae3b86d5309 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc @@ -0,0 +1,447 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_propagation.h" + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +BFloat16Propagation::BFloat16Propagation( + const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + +void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision( + HloInstruction* fusion) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) { + return; + } + + // We are depending on the fusion node itself having already been analyzed + // for whether it can output BF16 and this has been adjusted in the output + // shape, and now we're looking to update the interior of the fusion node to + // match the new output shape, as well as recursively process the whole fusion + // node even if the output shape was not modified. + auto root = fusion->fused_instructions_computation()->root_instruction(); + + // Adjust root's element types according to the fusion's output shape. + ShapeUtil::ForEachMutableSubshape( + root->mutable_shape(), [&](Shape* subshape, const ShapeIndex& index) { + if (subshape->element_type() != F32) { + return; + } + if (ShapeUtil::GetSubshape(fusion->shape(), index).element_type() == + BF16) { + subshape->set_element_type(BF16); + changed_ = true; + VLOG(2) << "Fused root " << root->ToString() << " at shape index " + << index << " changed to BF16 precision for fusion " + << fusion->ToString(); + } + }); + + // Propagate BF16 in the fusion computation. + auto insts = + fusion->fused_instructions_computation()->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false); + } +} + +void BFloat16Propagation::AdjustFusionParameters(HloInstruction* fusion) { + CHECK_EQ(fusion->fused_parameters().size(), fusion->operand_count()); + for (int64 i = 0; i < fusion->operand_count(); ++i) { + auto parameter = fusion->fused_parameter(i); + ShapeUtil::ForEachMutableSubshape( + parameter->mutable_shape(), + [&](Shape* subshape, const ShapeIndex& index) { + if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) { + return; + } + PrimitiveType operand_type = + ShapeUtil::GetSubshape(fusion->operand(i)->shape(), index) + .element_type(); + if (subshape->element_type() == operand_type) { + return; + } + CHECK(operand_type == F32 || operand_type == BF16); + subshape->set_element_type(operand_type); + changed_ = true; + VLOG(2) << "Fused parameter " << parameter->ToString() + << " at shape index " << index + << " adjusted to match operand in fusion " + << fusion->ToString(); + }); + } +} + +bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo, + const ShapeIndex& index) const { + auto value_set = dataflow_->GetValueSet(&hlo, index); + for (const HloValue* value : value_set.values()) { + if (ContainsKey(values_that_must_be_kept_as_f32_, value)) { + return false; + } + if (value->shape().element_type() == BF16) { + continue; + } + for (const HloUse& use : value->uses()) { + if (use.instruction->opcode() == HloOpcode::kFusion) { + auto fused_parameter = + use.instruction->fused_parameter(use.operand_number); + if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index) + .element_type() != BF16) { + return false; + } + continue; + } + if (bfloat16_support_->EffectiveOperandPrecisionIsBF16( + *use.instruction, use.operand_number)) { + continue; + } + // If the op propagates precision and it outputs a BF16, then it's OK to + // supply BF16 also as the input. In the backward mutation pass, the users + // shapes should have already been processed. + PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID; + if (use.instruction->opcode() == HloOpcode::kTuple || + (use.instruction->opcode() == HloOpcode::kCrossReplicaSum && + ShapeUtil::IsTuple(use.instruction->shape()))) { + user_output_type = ShapeUtil::GetSubshape( + ShapeUtil::GetSubshape(use.instruction->shape(), + {use.operand_number}), + use.operand_index) + .element_type(); + } else { + user_output_type = use.instruction->shape().element_type(); + } + if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision( + *use.instruction, use.operand_number) && + user_output_type == BF16) { + continue; + } + return false; + } + } + return true; +} + +void BFloat16Propagation::DetermineAndMutateInstructionPrecision( + HloInstruction* hlo, bool skip_parameters) { + // We handle any fusion computation after the instruction is handled, because + // we need to know a fusion's output shape before propagating inside its fused + // computation. + auto cleaner = tensorflow::gtl::MakeCleanup([this, hlo] { + if (hlo->opcode() == HloOpcode::kFusion) { + DetermineAndMutateFusionComputationPrecision(hlo); + } + }); + + // Do not change precision for instructions related to entry and exit of a + // computation, and control flow, because this pass might break the interfaces + // or assumptions for them. + if (hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional || // + (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) { + return; + } + + // Prevent root instructions from having their output modified by recording + // all F32 output values as needing to stay as F32. + CHECK(hlo->parent() != nullptr); + if (hlo == hlo->parent()->root_instruction()) { + if (!hlo->parent()->IsFusionComputation()) { + ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& subshape, + const ShapeIndex& index) { + if (subshape.element_type() != F32) { + return; + } + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + // Since we use HloValues from the dataflow analysis, this can also + // affect HLO instructions beyond the root, e.g., if the root is a + // Tuple HLO, then its operands are also affected. + values_that_must_be_kept_as_f32_.insert(value); + } + }); + } + return; + } + + if (!ContainsKey(consider_using_bfloat16_, hlo)) { + return; + } + + if (!bfloat16_support_->SupportsBF16Output(*hlo)) { + return; + } + + ShapeUtil::ForEachMutableSubshape( + hlo->mutable_shape(), + [hlo, this](Shape* subshape, const ShapeIndex& index) { + if (subshape->element_type() == F32 && + AllUsersConsumeBF16(*hlo, index)) { + subshape->set_element_type(BF16); + changed_ = true; + VLOG(2) << "HloInstruction output at shape index " << index + << " changed to BF16 precision: " << hlo->ToString(); + } + }); +} + +bool BFloat16Propagation::InstructionIsCandidateForBF16Output( + HloInstruction* hlo) { + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && + hlo->opcode() != HloOpcode::kTuple && + hlo->opcode() != HloOpcode::kGetTupleElement && + hlo->shape().element_type() != BF16) { + for (int64 i = 0; i < hlo->operand_count(); ++i) { + if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, + i) || + !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) { + return false; + } + } + } + return true; +} + +Status BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers( + HloModule* module) { + std::list computations_topological_order = + module->MakeComputationPostOrder(); + for (auto comp_it = computations_topological_order.rbegin(); + comp_it != computations_topological_order.rend(); ++comp_it) { + auto insts = (*comp_it)->MakeInstructionPostOrder(); + // Do the adjustment on each instruction in the computation in reverse + // topological order. + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + auto hlo = *inst_it; + auto adjust_buffer = [this, hlo](Shape* subshape, + const ShapeIndex& index) { + if (subshape->element_type() != F32 && + subshape->element_type() != BF16) { + return; + } + PrimitiveType type = BF16; + for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) { + if (value->shape().element_type() == BF16) { + continue; + } + CHECK_EQ(value->shape().element_type(), F32); + type = F32; + break; + } + // It's possible that a user has been changed from BF16 to F32 + // during this final adjustment pass, so we need to check + // AllUsersConsumeBF16() again. + if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) { + type = F32; + } + if (type == F32) { + for (const auto* value : + dataflow_->GetValueSet(hlo, index).values()) { + // We rely on the fact that this adjustment works in reverse + // topological order. Adding the value to + // values_that_must_be_kept_as_f32_ will ensure the correctness + // of the adjustment for HLOs that will be processed later. + values_that_must_be_kept_as_f32_.insert(value); + } + } + subshape->set_element_type(type); + }; + ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_buffer); + } + // Now adjust parameters of fusions inside this computation. + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + auto hlo = *inst_it; + if (hlo->opcode() == HloOpcode::kFusion) { + AdjustFusionParameters(hlo); + } + } + } + + // We could have changed a fusion computation's root shape to have a different + // precision than the fusion node's output, if the fusion root does not + // define a buffer (e.g., a tuple). Now we add conversions after such fusion + // roots to make them match the fusion output. If the fusion output is a + // (possibly nested) tuple, we first create get-tuple-elements, then convert + // the unmatching leaf nodes, and finally create a new tuple as the fusion + // computation's root. If tuples and get-tuple-elements are created, we will + // run tuple simplifier and dead code elimination at the end (dead code is not + // allowed in fusion computation). E.g., + // + // (1) (2) (3) + // a b a b a b + // |\ | |\ | |\ | + // \ add -> |add -> | add + // \ | \ | convert | + // tuple tuple \ | + // / \ tuple + // gte gte + // | | + // convert | + // \ / + // tuple + // (1) a is F32 but tuple is BF16 + // (2) after adding conversion + // (3) after tuple simplifier and DCE. + bool needs_tuple_simplifier = false; + for (auto computation : computations_topological_order) { + auto insts = computation->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + auto hlo = *inst_it; + if (hlo->opcode() != HloOpcode::kFusion) { + continue; + } + auto fusion_computation = hlo->fused_instructions_computation(); + auto fusion_root = fusion_computation->root_instruction(); + if (ShapeUtil::Compatible(fusion_root->shape(), hlo->shape())) { + continue; + } + ShapeTree converted_outputs(hlo->shape()); + // Iterate through nodes in the shape tree in pre-order and initialize + // each non-root node with a corresponding get-tuple-element. For a leaf + // node, if its shape does not match the fusion output, create a + // conversion node to overwrite the node value. + for (auto it = converted_outputs.begin(); it != converted_outputs.end(); + ++it) { + ShapeIndex output_index = it->first; + HloInstruction*& output = it->second; + const Shape subshape = + ShapeUtil::GetSubshape(hlo->shape(), output_index); + if (output_index.empty()) { + output = fusion_root; + } else { + ShapeIndex parent_index = output_index; + parent_index.pop_back(); + output = fusion_computation->AddInstruction( + HloInstruction::CreateGetTupleElement( + subshape, converted_outputs.element(parent_index), + output_index.back())); + } + if (ShapeUtil::IsTuple(subshape)) { + continue; + } + if (!ShapeUtil::Compatible( + subshape, + ShapeUtil::GetSubshape(fusion_root->shape(), output_index))) { + output = fusion_computation->AddInstruction( + HloInstruction::CreateConvert(subshape, output)); + } + } + // Iterate through nodes in the shape tree in reverse pre-order and create + // a tuple instruction for each non-leaf node where the elements are the + // values of its child nodes. + for (auto it = converted_outputs.rbegin(); it != converted_outputs.rend(); + ++it) { + ShapeIndex output_index = it->first; + HloInstruction*& output = it->second; + const Shape& subshape = + ShapeUtil::GetSubshape(hlo->shape(), output_index); + if (!ShapeUtil::IsTuple(subshape)) { + continue; + } + std::vector elements( + ShapeUtil::TupleElementCount(subshape)); + ShapeIndex child_index = output_index; + for (int64 i = 0; i < elements.size(); ++i) { + child_index.push_back(i); + elements[i] = converted_outputs.element(child_index); + child_index.pop_back(); + } + output = fusion_computation->AddInstruction( + HloInstruction::CreateTuple(elements)); + } + fusion_computation->set_root_instruction(converted_outputs.element({})); + needs_tuple_simplifier |= ShapeUtil::IsTuple(hlo->shape()); + } + } + if (needs_tuple_simplifier) { + TupleSimplifier tuple_simplifier; + TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + return Status::OK(); +} + +// The algorithm first does a forward pass (parameters to root) to determine a +// set of instructions to consider using bfloat16, then does a backward pass to +// determine the precisions of those instructions according to the need of +// their users. +StatusOr BFloat16Propagation::Run(HloModule* module) { + TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module)); + + std::list computations_topological_order = + module->MakeComputationPostOrder(); + // The first step is a forward pass (parameters to root), where we determine + // the potential candidate instructions to use bfloat16 in the outputs that + // are not likely to cause overhead from extra explicit conversions. This is + // done forwardly because we determine whether an HLO is a candidate partially + // based on whether its operands are candidates. + for (auto computation : computations_topological_order) { + for (auto inst : computation->MakeInstructionPostOrder()) { + if (InstructionIsCandidateForBF16Output(inst)) { + consider_using_bfloat16_.insert(inst); + } + } + } + + // The second step is a backward pass (root to parameters), where we modify + // the precisions of the instructions identified in the first step when + // feasible. This is done backwardly because we determine the precision of an + // HLO's output based on how it is later used. + // + // The precision of an instruction is determined by its users, so we do the + // propagation in reverse topological order. + for (auto comp_it = computations_topological_order.rbegin(); + comp_it != computations_topological_order.rend(); ++comp_it) { + if ((*comp_it)->IsFusionComputation()) { + // Fusion computations are handled when visiting the fusion instruction. + continue; + } + auto insts = (*comp_it)->MakeInstructionPostOrder(); + for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) { + DetermineAndMutateInstructionPrecision(*inst_it, + /*skip_parameters=*/true); + } + } + + if (!changed_) { + return false; + } + + // It's possible that an instruction does not define a buffer, but the + // defining instruction's shape has changed. So we need to adjust the output + // shapes of instructions according to the HLO values they refer to. + TF_RETURN_IF_ERROR(ResolveInconsistencyOfAliasingBuffers(module)); + return true; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h new file mode 100644 index 0000000000000000000000000000000000000000..ccf77d7b4eb6bd7b76b1b6743bd724f42c141f08 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_ + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// HLO pass which reduces the precision of some HLO instructions to BF16 +// according to the backend-specific BFloat16Support rule provided by the +// caller. +// +// This pass can be used to reduce instruction precision without affecting the +// numerical accuracy of the module, i.e., the final output of the module would +// be bitwise identical to that without this pass; this is possible if the +// backend already reduces precision to BF16 on some HLO instructions. +// +// This pass will not modify the signature of any non-fusion computation. +// +// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs, +// which has two issues: +// +// 1) It does not guarantee to respect the passed-in BFloat16Support +// specification in terms of mixed precision, so the backend may not support an +// HLO that has mixed precision produced by this pass. To address this issue, +// run BFloat16Normalization with the same BFloat16Support after this pass. +// +// 2) In general, mixed precision may break the assumptions of some other HLO +// passes even if the specific backend supports the individual HLOs. Such +// assumptions include that there are no HLOs using mixed precision, or that the +// precision of an HLO's output is determined by its inputs. It should be used +// at the end of the HLO optimization pipeline but before +// BFloat16ConversionFolding. If other passes are needed after this pass, run +// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this +// pass. +class BFloat16Propagation : public HloPassInterface { + public: + explicit BFloat16Propagation(const BFloat16Support* bfloat16_support); + + ~BFloat16Propagation() override = default; + + tensorflow::StringPiece name() const override { + return "bfloat16-propagation"; + } + + // Runs the pass on the given module. Returns whether the module was changed + // (precision reductions were added). + StatusOr Run(HloModule* module) override; + + private: + // *************************** + // Function called and state produced by the forward analysis pass (from + // parameters to root) that determines the candidate HLOs to use BF16 outputs. + + // Determines whether we should consider changing the precision of the given + // instruction in the forward pass. + bool InstructionIsCandidateForBF16Output(HloInstruction* hlo); + + // The set of instructions to consider using bfloat16, computed in the forward + // pass. + tensorflow::gtl::FlatSet consider_using_bfloat16_; + + // *************************** + // Functions called and state produced by the backward mutation pass (from + // root to parameters). + + // Determines the precision for the given instruction in the mutation pass. + void DetermineAndMutateInstructionPrecision(HloInstruction* hlo, + bool skip_parameters); + + // Special handling in the mutation pass for fusion computations. + void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion); + + // *************************** + // Functions called by the final inconsistency resolving pass. + + // Adjusts the output shapes of HloInstructions such that if two + // HloInstructions have aliasing buffers in their outputs, they must have the + // same precision. + Status ResolveInconsistencyOfAliasingBuffers(HloModule* module); + + // Makes the fusion parameters match the precision of the actual parameters + // passed to the fusion node. + void AdjustFusionParameters(HloInstruction* fusion); + + // *************************** + // Functions called and state used by two or more passes. + + // Returns whether all uses of the given HloInstruction can consume BF16 + // input. + bool AllUsersConsumeBF16(const HloInstruction& hlo, + const ShapeIndex& index) const; + + // The set of F32 HLO values that must be kept in F32. + tensorflow::gtl::FlatSet values_that_must_be_kept_as_f32_; + + // *************************** + // State used by both passes. + const BFloat16Support* bfloat16_support_; + std::unique_ptr dataflow_; + + bool changed_ = false; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2047e2053a1a819a2d534f34fc4ba2f8768dc861 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -0,0 +1,393 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_propagation.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +// A class specifying the BF16 support used to test the propagation pass. It +// specifies that BF16 and mixed precision are supported in all HloInstructions, +// and that kDot reduces its operands precision to BF16. +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + return true; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + return true; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + return true; + } + + bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo, + int64 operand_index) const override { + return hlo.opcode() == HloOpcode::kDot; + } +}; + +class BFloat16PropagationTest : public HloTestBase { + protected: + // Runs the propagation pass on the given module, and returns whether the + // module is changed after this pass. + bool PropagatePrecision(HloModule* module) { + TestBFloat16Support bfloat16_support; + BFloat16Propagation propagation(&bfloat16_support); + StatusOr result = propagation.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } + + // Returns whether the given HloInstruction's output element type is BF16 or + // the only use of it is converting to BF16. + bool OutputsBF16(const HloInstruction* inst) { + if (inst->shape().element_type() == BF16) { + return true; + } + return inst->user_count() == 1 && + inst->users()[0]->opcode() == HloOpcode::kConvert && + inst->users()[0]->shape().element_type() == BF16; + } +}; + +// Tests that BF16 can propagate through select over non-tuple buffers, but not +// through add where reducing operand precision can affect the result. +TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* c = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b)); + HloInstruction* pred = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b)); + HloInstruction* sel = builder.AddInstruction( + HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(xpose)); + EXPECT_TRUE(OutputsBF16(sel)); + EXPECT_TRUE(OutputsBF16(add1)); + EXPECT_FALSE(OutputsBF16(add0)); + EXPECT_FALSE(OutputsBF16(a)); + EXPECT_FALSE(OutputsBF16(b)); + EXPECT_FALSE(OutputsBF16(c)); +} + +// Tests that BF16 can be propagated through nested tuples. +TEST_F(BFloat16PropagationTest, PropagateThroughTuples) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a)); + HloInstruction* add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, b)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0})); + + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1, add2})); + HloInstruction* tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({tuple0, xpose})); + + HloInstruction* lhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(xpose->shape(), tuple1, 1)); + HloInstruction* rhs = + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + add0->shape(), + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + tuple0->shape(), tuple1, 0)), + 0)); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + + HloInstruction* output_tuple = + builder.AddInstruction(HloInstruction::CreateTuple({dot, add2})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), output_tuple); + EXPECT_TRUE(OutputsBF16(xpose)); + EXPECT_TRUE(OutputsBF16(add0)); + EXPECT_TRUE(OutputsBF16(add1)); + EXPECT_FALSE(OutputsBF16(add2)); +} + +// Tests that even if an instruction does not define a buffer in its output, its +// shape must match the defining instruction. +TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a)); + + HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0})); + + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add0, add1})); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1)); + + // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1. + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(add0)); + EXPECT_TRUE(OutputsBF16(add1)); + EXPECT_TRUE(OutputsBF16(lhs)); + // rhs is a get-tuple-element, which does not define a buffer, but its shape + // should also be adjusted accordingly. + EXPECT_TRUE(OutputsBF16(rhs)); +} + +// Tests that a non-fusion computation's root should not be changed. +TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* a = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b)); + + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add)); + + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({add, dot})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), tuple); + EXPECT_FALSE(OutputsBF16(add)); +} + +// Tests that BF16 is propagated properly through fused computations. +TEST_F(BFloat16PropagationTest, PropagateThroughFusion) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f0 = HloComputation::Builder("fusion0"); + HloInstruction* a_f0 = + builder_f0.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f0 = + builder_f0.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* tuple_f0 = + builder_f0.AddInstruction(HloInstruction::CreateTuple({a_f0, b_f0})); + auto comp_f0 = module->AddEmbeddedComputation(builder_f0.Build()); + auto fusion0 = builder.AddInstruction(HloInstruction::CreateFusion( + tuple_f0->shape(), HloInstruction::FusionKind::kCustom, {add, add}, + comp_f0)); + + auto builder_f1 = HloComputation::Builder("fusion1"); + HloInstruction* p_f1 = builder_f1.AddInstruction( + HloInstruction::CreateParameter(0, tuple_f0->shape(), "param")); + HloInstruction* a_f1 = builder_f1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, p_f1, 0)); + HloInstruction* b_f1 = builder_f1.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, p_f1, 1)); + HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1)); + auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build()); + auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion( + dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), fusion1); + EXPECT_TRUE(OutputsBF16(add)); + EXPECT_TRUE(OutputsBF16(a_f0)); + EXPECT_TRUE(OutputsBF16(b_f0)); + EXPECT_TRUE(OutputsBF16(a_f1)); + EXPECT_TRUE(OutputsBF16(b_f1)); +} + +// Tests that if 1) the root instruction of a fusion is a tuple, 2) the fusion +// outputs are only used by a dot, and 3) one element of the tuple is used by +// an add in the fusion computation, then the propagation pass should create a +// convert in the fusion computation to keep the add's operand in F32 but change +// the fusion output to BF16. E.g., the following fusion computation +// (F32, F32) fusion_computation(F32 a, F32 b) +// = tuple(F32 a, F32 add(F32 a, F32 b)) +// will be changed to +// (BF16, BF16) fusion_computation(F32 a, F32 b) +// = tuple(BF16 convert(a), BF16 add(F32 a, F32 b)) +TEST_F(BFloat16PropagationTest, ConvertTupleFusionElementIfUsedByAdd) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + + auto builder_f = HloComputation::Builder("fusion0"); + HloInstruction* a_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(0, shape, "a")); + HloInstruction* b_f = + builder_f.AddInstruction(HloInstruction::CreateParameter(1, shape, "b")); + HloInstruction* add_f = builder_f.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a_f, b_f)); + HloInstruction* tuple_f = + builder_f.AddInstruction(HloInstruction::CreateTuple({a_f, add_f})); + auto comp_f = module->AddEmbeddedComputation(builder_f.Build()); + auto fusion = builder.AddInstruction(HloInstruction::CreateFusion( + tuple_f->shape(), HloInstruction::FusionKind::kCustom, {add, add}, + comp_f)); + + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, fusion, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, fusion, 1)); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, gte0, gte1)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_TRUE(OutputsBF16(gte0)); + EXPECT_TRUE(OutputsBF16(gte1)); + EXPECT_FALSE(OutputsBF16(a_f)); + EXPECT_FALSE(OutputsBF16(b_f)); + EXPECT_TRUE(OutputsBF16(add_f)); + auto new_fusion_root = comp_f->root_instruction(); + EXPECT_EQ(new_fusion_root->opcode(), HloOpcode::kTuple); + EXPECT_EQ(new_fusion_root->operand(1), add_f); + EXPECT_EQ(new_fusion_root->operand(0)->opcode(), HloOpcode::kConvert); + EXPECT_TRUE(OutputsBF16(new_fusion_root->operand(0))); +} + +// A select over tuples does not define the leaf buffers, so the types in +// on_true and on_false must match, so that as long as one of them is F32, the +// other must be F32 as well. +TEST_F(BFloat16PropagationTest, SelectOverTuples) { + auto module = CreateNewModule(); + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param")); + HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(PRED, {}), "pred")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param)); + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, param)); + HloInstruction* tuple0 = + builder.AddInstruction(HloInstruction::CreateTuple({param, add0})); + HloInstruction* tuple1 = + builder.AddInstruction(HloInstruction::CreateTuple({param, add1})); + HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary( + tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1)); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, sel, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, sel, 1)); + HloInstruction* xpose = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0})); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), dot); + EXPECT_FALSE(OutputsBF16(add0)); + EXPECT_FALSE(OutputsBF16(add1)); + EXPECT_FALSE(OutputsBF16(gte0)); + EXPECT_FALSE(OutputsBF16(gte1)); + EXPECT_TRUE(OutputsBF16(xpose)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc new file mode 100644 index 0000000000000000000000000000000000000000..07b4b14b5ec1bdbc01345091105df69368b0b2fb --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -0,0 +1,112 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { + +bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + case HloOpcode::kConvert: + CHECK_EQ(operand_index, 0); + return hlo.operand(0)->shape().element_type() == BF16; + default: + break; + } + return false; +} + +bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + case HloOpcode::kConvert: + return hlo.shape().element_type() == BF16; + default: + break; + } + return false; +} + +bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + default: + break; + } + return false; +} + +/* static */ +bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( + const HloInstruction& hlo, int64 operand_index) { + switch (hlo.opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kBroadcast: + case HloOpcode::kClamp: + case HloOpcode::kConcatenate: + case HloOpcode::kConvert: + case HloOpcode::kCopy: + case HloOpcode::kGetTupleElement: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case HloOpcode::kReverse: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return true; + case HloOpcode::kDynamicSlice: + return operand_index == 0; + case HloOpcode::kDynamicUpdateSlice: + return operand_index == 0 || operand_index == 1; + case HloOpcode::kSelect: + return operand_index == 1 || operand_index == 2; + default: + break; + } + return false; +} + +bool BFloat16Support::EffectiveOperandPrecisionIsBF16( + const HloInstruction& hlo, int64 operand_index) const { + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.h b/tensorflow/compiler/xla/service/bfloat16_support.h new file mode 100644 index 0000000000000000000000000000000000000000..82c2745f444e4f9c544c78cb36dafc11f678518a --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_support.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { + +class BFloat16Support { + public: + BFloat16Support() {} + virtual ~BFloat16Support() {} + + // Returns whether the backend supports BF16 operand for the HLO instruction + // at the given index. + virtual bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const; + + // Returns whether the backend supports BF16 output for the HLO instruction. + virtual bool SupportsBF16Output(const HloInstruction& hlo) const; + + // Returns whether the backend support mixed precision: the operands, output, + // and parameters/output of the called computations can have different + // precisions (BF16 and F32). + virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const; + + // Returns whether the given HLO preserves its BF16 operand precision at the + // given index, so even if the output is F32, elements in the output that + // depend on the BF16 operand will still have BF16 effective precision even if + // they have F32 format. Similarly, this also means if the output is BF16 then + // increasing the operand precision from BF16 to F32 will not change the + // output. This typically includes HLOs that pass elements from the operand to + // the output without arithmetic operations. + static bool EffectiveOperandPrecisionIsOutputPrecision( + const HloInstruction& hlo, int64 operand_index); + + // Returns if the backend only uses BF16 precision for the operand at the + // specified index, even if the operand is F32. + virtual bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo, + int64 operand_index) const; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index f0a9de5f94634680d84278a2d2ca6d8c50d1e355..d44d3d71d9f28de0fd38f0c1c3aac3cf7418255e 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -45,6 +45,185 @@ using ::tensorflow::gtl::FlatMap; using ::tensorflow::gtl::FlatSet; using ::tensorflow::strings::Appendf; using ::tensorflow::strings::HumanReadableNumBytes; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; + +namespace { + +template +string ColocatedBufferSetsToString(const T& container, const char* title) { + string result; + StrAppend(&result, title, "\n"); + for (const auto& it : container) { + StrAppend(&result, "\t", it->ToString(), "\n"); + } + return result; +} + +// Walk the call graph of the HLO module and place each computation into either +// thread_local_computations or global_computations depending upon whether the +// computation requires thread-local allocations or global allocations. The +// elements in thread_local_computations and global_computations are in post +// order (if computation A has an instruction which calls computation B, then A +// will appear after B in the vector). +Status GatherComputationsByAllocationType( + const HloModule* module, + std::vector* thread_local_computations, + std::vector* global_computations) { + // Create a worklist of computations paired with whether the allocation must + // be thread-local. + std::deque> worklist; + worklist.push_back(std::make_pair(module->entry_computation(), + /*is_thread_local*/ false)); + + // Sets for quickly checking membership. Computations are returned in vectors + // for stable iteration. + FlatSet thread_local_set; + FlatSet global_set; + + while (!worklist.empty()) { + auto worklist_front = worklist.front(); + worklist.pop_front(); + const HloComputation* computation = worklist_front.first; + bool is_thread_local = worklist_front.second; + bool in_thread_local_set = thread_local_set.count(computation) > 0; + bool in_global_set = global_set.count(computation) > 0; + + // If the computation has already been added to the respective set, then + // nothing to do. + if ((is_thread_local && in_thread_local_set) || + (!is_thread_local && in_global_set)) { + continue; + } + + // If the computation has already been added to the other set this is an + // error condition because the global call to the computation (eg, + // while/call) may return a reference to one of the thread-local buffers to + // the calling computation which will become a dangling reference when the + // thread-local is deallocated with the call return. + if ((is_thread_local && in_global_set) || + (!is_thread_local && in_thread_local_set)) { + return InvalidArgument( + "computation %s has conflicting allocation requirements (global " + "and thread-local)", + computation->name().c_str()); + } + + if (is_thread_local) { + thread_local_set.insert(computation); + } else { + global_set.insert(computation); + } + + for (auto* instruction : computation->instructions()) { + for (HloComputation* subcomputation : + instruction->called_computations()) { + switch (instruction->opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kWhile: + // Call and while must be called from a computation with global + // allocations as they may return references to buffers inside the + // called computation which cannot be thread-local. + if (is_thread_local) { + return InvalidArgument( + "computation %s cannot contain call/while op because it " + "requires thread-local buffer allocations", + computation->name().c_str()); + } + worklist.push_back(std::make_pair(subcomputation, + false)); // Not thread local. + break; + case HloOpcode::kMap: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kFusion: + // Map/reduce etc computations are always thread-local. + worklist.push_back(std::make_pair(subcomputation, + true)); // Thread local. + break; + default: + return InternalError( + "Unexpected calling opcode: %s", + HloOpcodeString(instruction->opcode()).c_str()); + } + } + } + } + + // Add the computations to the vectors in post order. + for (auto* computation : module->MakeComputationPostOrder()) { + if (thread_local_set.count(computation) > 0) { + thread_local_computations->push_back(computation); + } else if (global_set.count(computation) > 0) { + global_computations->push_back(computation); + } + // If the computation is not reachable from the entry computation, then it + // will not appear in either thread_local_set or global_set. We don't bother + // assigning buffers for these. + } + return Status::OK(); +} + +// Checks that points-to set of 'instruction' is unambiguous and distinct +// (ensured by CopyInsertion), then adds the buffer from the points-to set at +// 'index' to 'colocated_set'. +const LogicalBuffer* AddBufferToColocatedSet( + const HloInstruction* instruction, const ShapeIndex& index, + const TuplePointsToAnalysis& points_to_analysis, + std::vector* colocated_set) { + // CopyInsertion ensures root points-to set is unambiguous and distinct. + const auto& points_to = points_to_analysis.GetPointsToSet(instruction); + DCHECK(!points_to.IsAmbiguous()); + colocated_set->push_back(points_to.element(index)[0]); + return colocated_set->back(); +} + +// Given the interference map of a graph (the list of interfering node indices +// for each node), perform graph coloring such that interfering nodes are +// assigned to different colors. Returns the assigned color of the nodes, where +// the colors are represented as integer values [0, color_count). +std::vector ColorInterferenceGraph( + const std::vector>& interference_map) { + const int64 node_count = interference_map.size(); + + // Sort the nodes such that we assign nodes with more interference first. This + // relies on the common heuristic of assigning the most constrained node + // first, but it would be good to investigate other ordering heuristics too. + std::vector nodes(node_count); + std::iota(nodes.begin(), nodes.end(), 0); + std::sort(nodes.begin(), nodes.end(), + [&interference_map](const int64 i, const int64 j) { + return interference_map[i].size() > interference_map[j].size(); + }); + + const int64 kColorUnassigned = -1; + std::vector assigned_colors(node_count, kColorUnassigned); + for (int64 node : nodes) { + // Mark the colors that are already assigned to the neighbors. + std::vector available_colors(node_count, true); + for (int64 neighbor : interference_map[node]) { + int64 color = assigned_colors[neighbor]; + if (color != kColorUnassigned) { + available_colors[color] = false; + } + } + + // Find the color that is not yet assigned to the neighbors. + int64 color = kColorUnassigned; + for (color = 0; color < available_colors.size(); ++color) { + if (available_colors[color]) { + break; + } + } + CHECK_NE(color, kColorUnassigned); + assigned_colors[node] = color; + } + return assigned_colors; +} + +} // namespace size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { uint64 h = std::hash()(s.index()); @@ -93,6 +272,9 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto.set_color(color_.value()); if (is_entry_computation_parameter_) { proto.set_is_entry_computation_parameter(true); + for (int64 idx : param_shape_index()) { + proto.add_parameter_shape_index(idx); + } proto.set_parameter_number(parameter_number_); } proto.set_maybe_live_out(maybe_live_out_); @@ -112,25 +294,24 @@ BufferAllocationProto BufferAllocation::ToProto() const { string BufferAllocation::ToString() const { string output; - tensorflow::strings::StrAppend( - &output, tensorflow::strings::Printf("allocation %lld: %p, size %lld", - index_, this, size())); + Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); if (color().value() != 0) { - tensorflow::strings::StrAppend(&output, ", color ", color().value()); + StrAppend(&output, ", color ", color().value()); } if (is_entry_computation_parameter()) { - tensorflow::strings::StrAppend(&output, ", parameter ", parameter_number()); + StrAppend(&output, ", parameter ", parameter_number(), " at ShapeIndex ", + param_shape_index().ToString()); } if (is_thread_local()) { - tensorflow::strings::StrAppend(&output, ", thread-local"); + StrAppend(&output, ", thread-local"); } if (maybe_live_out()) { - tensorflow::strings::StrAppend(&output, ", maybe-live-out"); + StrAppend(&output, ", maybe-live-out"); } if (IsPreallocatedTempBuffer()) { - tensorflow::strings::StrAppend(&output, ", preallocated-temp"); + StrAppend(&output, ", preallocated-temp"); } - tensorflow::strings::StrAppend(&output, ":\n"); + StrAppend(&output, ":\n"); // Dump the assigned buffers ordered by id. std::vector sorted_buffers; for (const auto& buffer_offset_size : assigned_buffers_) { @@ -142,12 +323,11 @@ string BufferAllocation::ToString() const { }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); - tensorflow::strings::StrAppend( - &output, - tensorflow::strings::Printf( - " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), - offset_size.offset, offset_size.size, - ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); + StrAppend(&output, + tensorflow::strings::Printf( + " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), + offset_size.offset, offset_size.size, + ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); } return output; } @@ -520,116 +700,6 @@ BufferAssignmentProto BufferAssignment::ToProto() const { return proto; } -namespace { - -// Walk the call graph of the HLO module and place each computation into either -// thread_local_computations or global_computations depending upon whether the -// computation requires thread-local allocations or global allocations. The -// elements in thread_local_computations and global_computations are in post -// order (if computation A has an instruction which calls computation B, then A -// will appear after B in the vector). -Status GatherComputationsByAllocationType( - const HloModule* module, - std::vector* thread_local_computations, - std::vector* global_computations) { - // Create a worklist of computations paired with whether the allocation must - // be thread-local. - std::deque> worklist; - worklist.push_back(std::make_pair(module->entry_computation(), - /*is_thread_local*/ false)); - - // Sets for quickly checking membership. Computations are returned in vectors - // for stable iteration. - FlatSet thread_local_set; - FlatSet global_set; - - while (!worklist.empty()) { - auto worklist_front = worklist.front(); - worklist.pop_front(); - const HloComputation* computation = worklist_front.first; - bool is_thread_local = worklist_front.second; - bool in_thread_local_set = thread_local_set.count(computation) > 0; - bool in_global_set = global_set.count(computation) > 0; - - // If the computation has already been added to the respective set, then - // nothing to do. - if ((is_thread_local && in_thread_local_set) || - (!is_thread_local && in_global_set)) { - continue; - } - - // If the computation has already been added to the other set this is an - // error condition because the global call to the computation (eg, - // while/call) may return a reference to one of the thread-local buffers to - // the calling computation which will become a dangling reference when the - // thread-local is deallocated with the call return. - if ((is_thread_local && in_global_set) || - (!is_thread_local && in_thread_local_set)) { - return InvalidArgument( - "computation %s has conflicting allocation requirements (global " - "and thread-local)", - computation->name().c_str()); - } - - if (is_thread_local) { - thread_local_set.insert(computation); - } else { - global_set.insert(computation); - } - - for (auto* instruction : computation->instructions()) { - for (HloComputation* subcomputation : - instruction->called_computations()) { - switch (instruction->opcode()) { - case HloOpcode::kCall: - case HloOpcode::kConditional: - case HloOpcode::kWhile: - // Call and while must be called from a computation with global - // allocations as they may return references to buffers inside the - // called computation which cannot be thread-local. - if (is_thread_local) { - return InvalidArgument( - "computation %s cannot contain call/while op because it " - "requires thread-local buffer allocations", - computation->name().c_str()); - } - worklist.push_back(std::make_pair(subcomputation, - false)); // Not thread local. - break; - case HloOpcode::kMap: - case HloOpcode::kReduce: - case HloOpcode::kReduceWindow: - case HloOpcode::kSelectAndScatter: - case HloOpcode::kFusion: - // Map/reduce etc computations are always thread-local. - worklist.push_back(std::make_pair(subcomputation, - true)); // Thread local. - break; - default: - return InternalError( - "Unexpected calling opcode: %s", - HloOpcodeString(instruction->opcode()).c_str()); - } - } - } - } - - // Add the computations to the vectors in post order. - for (auto* computation : module->MakeComputationPostOrder()) { - if (thread_local_set.count(computation) > 0) { - thread_local_computations->push_back(computation); - } else if (global_set.count(computation) > 0) { - global_computations->push_back(computation); - } - // If the computation is not reachable from the entry computation, then it - // will not appear in either thread_local_set or global_set. We don't bother - // assigning buffers for these. - } - return Status::OK(); -} - -} // namespace - /* static */ StatusOr> BufferAssigner::Run( const HloModule* module, std::unique_ptr hlo_ordering, @@ -840,7 +910,7 @@ Status BufferAssigner::AssignBuffersForComputation( /*is_thread_local=*/false, /*is_reusable=*/false); allocation->set_entry_computation_parameter( - instruction->parameter_number()); + instruction->parameter_number(), buffer->index()); VLOG(3) << "New allocation #" << allocation->index() << " for entry computation parameter: " << *buffer; continue; @@ -1082,7 +1152,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( if (colocated_set.empty()) { return; } - + VLOG(5) << ColocatedBufferSetsToString(colocated_set, + "Adding colocated buffer set"); // Find existing sets that overlap with at least one buffer from the // colocated_set. The resulting 'overlap_set_indices' will have at most // colocated_buffer_sets->size() entries, and will be in increasing order. @@ -1090,6 +1161,10 @@ void BufferAssigner::AddSetToColocatedBufferSets( for (size_t index = 0; index < colocated_buffer_sets->size(); ++index) { for (const LogicalBuffer* buffer : colocated_set) { if ((*colocated_buffer_sets)[index].count(buffer) > 0) { + VLOG(5) << "Found overlap with existing set on buffer " + << buffer->ToString() << "\n" + << ColocatedBufferSetsToString((*colocated_buffer_sets)[index], + "Overlapping set"); overlap_set_indices.push_back(index); break; } @@ -1101,6 +1176,7 @@ void BufferAssigner::AddSetToColocatedBufferSets( colocated_buffer_sets->emplace_back(); colocated_buffer_sets->back().insert(colocated_set.begin(), colocated_set.end()); + VLOG(5) << "No overlap found, new group created"; return; } @@ -1112,6 +1188,8 @@ void BufferAssigner::AddSetToColocatedBufferSets( first->insert(overlap_set.begin(), overlap_set.end()); } first->insert(colocated_set.begin(), colocated_set.end()); + VLOG(5) << ColocatedBufferSetsToString( + *first, "Result of the colocated buffer set merging"); // Remove overlap sets that we just merged. The offset accounts for the fact // that as elements are erased, the indices need to be adjusted. Keep in mind @@ -1122,67 +1200,6 @@ void BufferAssigner::AddSetToColocatedBufferSets( } } -namespace { - -// Checks that points-to set of 'instruction' is unambiguous and distinct -// (ensured by CopyInsertion), then adds the buffer from the points-to set at -// 'index' to 'colocated_set'. -const LogicalBuffer* AddBufferToColocatedSet( - const HloInstruction* instruction, const ShapeIndex& index, - const TuplePointsToAnalysis& points_to_analysis, - std::vector* colocated_set) { - // CopyInsertion ensures root points-to set is unambiguous and distinct. - const auto& points_to = points_to_analysis.GetPointsToSet(instruction); - DCHECK(!points_to.IsAmbiguous()); - colocated_set->push_back(points_to.element(index)[0]); - return colocated_set->back(); -} - -// Given the interference map of a graph (the list of interfering node indices -// for each node), perform graph coloring such that interfering nodes are -// assigned to different colors. Returns the assigned color of the nodes, where -// the colors are represented as integer values [0, color_count). -std::vector ColorInterferenceGraph( - const std::vector>& interference_map) { - const int64 node_count = interference_map.size(); - - // Sort the nodes such that we assign nodes with more interference first. This - // relies on the common heuristic of assigning the most constrained node - // first, but it would be good to investigate other ordering heuristics too. - std::vector nodes(node_count); - std::iota(nodes.begin(), nodes.end(), 0); - std::sort(nodes.begin(), nodes.end(), - [&interference_map](const int64 i, const int64 j) { - return interference_map[i].size() > interference_map[j].size(); - }); - - const int64 kColorUnassigned = -1; - std::vector assigned_colors(node_count, kColorUnassigned); - for (int64 node : nodes) { - // Mark the colors that are already assigned to the neighbors. - std::vector available_colors(node_count, true); - for (int64 neighbor : interference_map[node]) { - int64 color = assigned_colors[neighbor]; - if (color != kColorUnassigned) { - available_colors[color] = false; - } - } - - // Find the color that is not yet assigned to the neighbors. - int64 color = kColorUnassigned; - for (color = 0; color < available_colors.size(); ++color) { - if (available_colors[color]) { - break; - } - } - CHECK_NE(color, kColorUnassigned); - assigned_colors[node] = color; - } - return assigned_colors; -} - -} // namespace - std::vector BufferAssigner::MergeColocatedBufferSets( const std::vector& colocated_buffer_sets, @@ -1411,14 +1428,17 @@ void BufferAssigner::AssignColocatedBufferSets( FlatSet* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; - // Set 'entry_parameter_number' if entry param in 'colocated_buffer_set'. + // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry + // param in 'colocated_buffer_set'. int64 entry_parameter_number = -1; + const ShapeIndex* entry_parameter_shape_idx = nullptr; for (const LogicalBuffer* buffer : colocated_buffer_set) { const HloInstruction* instruction = buffer->instruction(); const HloComputation* computation = instruction->parent(); if (instruction->opcode() == HloOpcode::kParameter && computation == computation->parent()->entry_computation()) { entry_parameter_number = instruction->parameter_number(); + entry_parameter_shape_idx = &buffer->index(); break; } } @@ -1439,7 +1459,8 @@ void BufferAssigner::AssignColocatedBufferSets( // body computation (which updates in place). // Set 'entry_computation_parameter' to indicate that it contains // an entry parameter, and to prevent reuse in MaybeAssignBuffer. - allocation->set_entry_computation_parameter(entry_parameter_number); + allocation->set_entry_computation_parameter( + entry_parameter_number, *entry_parameter_shape_idx); } colocated_allocations->insert(allocation->index()); } else { diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 65019b6b17ca2536f0e4b31b8f36d2d167be8661..6b7fd0014d103ef0617afcc5cb3f663554a01aa4 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -91,6 +91,13 @@ class BufferAllocation { return parameter_number_; } + // If this allocation is for a parameter of the entry computation, this + // function returns which subshape of the parameter the allocation is for. + const ShapeIndex& param_shape_index() const { + CHECK(is_entry_computation_parameter_); + return param_shape_index_; + } + // Returns whether this allocation is assigned a LogicalBuffer which may // be live out of the entry computation. bool maybe_live_out() const { return maybe_live_out_; } @@ -203,9 +210,11 @@ class BufferAllocation { // Adds a LogicalBuffer to the set assigned to this buffer. void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size); - void set_entry_computation_parameter(int64 parameter_number) { + void set_entry_computation_parameter(int64 parameter_number, + ShapeIndex param_shape_index) { is_entry_computation_parameter_ = true; parameter_number_ = parameter_number; + param_shape_index_ = std::move(param_shape_index); } void set_maybe_live_out(bool value) { maybe_live_out_ = value; } void set_index(Index index) { index_ = index; } @@ -235,6 +244,10 @@ class BufferAllocation { // indicates the index (starting from 0) of the parameter. int64 parameter_number_ = 0; + // If this buffer is for an entry computation parameter, which subshape of the + // parameter is it for? + ShapeIndex param_shape_index_; + // Whether the allocation contains a LogicalBuffer which may be live-out of // the entry computation. Note that this flag is conservatively computed by // TuplePointsToAnalysis. That is, an allocation marked `maybe_live_out_` diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index dab73596e1639eed62151197048ee8d29570b20a..6664496ab6c603c35c7dce923fcf94c54d1ce714 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -72,8 +72,7 @@ CompileOnlyService::CompileAheadOfTime( VersionedComputationHandle versioned_handle = user_computation->GetVersionedHandle(); - // TODO(b/63773457): Track DebugOptions in AotCompilationOptions. - DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags(); + const DebugOptions& debug_options = options.debug_options(); // Dump computation proto state if flag is set. const string& directory_path = debug_options.xla_dump_computations_to(); diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index e2e9d2a0c048fec6c6ffbeef1223ae0e6aef50d1..0392d4af48a040c4a648f7bf9bf21a62ce03a990 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -86,4 +86,7 @@ Compiler::GetPlatformCompilers() { return compilers->at(platform->id()).get(); } +AotCompilationOptions::AotCompilationOptions() + : debug_options_(legacy_flags::GetDebugOptionsFromFlags()) {} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index 74fd24edf88d44b2dfdc87556b0af43987e69e08..33e19efc72c6d30ccd7e0b3a13f664a4f42208bf 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -79,11 +79,15 @@ class AotCompilationOptions { device_allocator_ = device_allocator; } + const DebugOptions& debug_options() const { return debug_options_; } + DebugOptions* mutable_debug_options() { return &debug_options_; } + protected: - AotCompilationOptions() = default; + AotCompilationOptions(); private: DeviceMemoryAllocator* device_allocator_ = nullptr; + DebugOptions debug_options_; }; // Abstract compiler interface that is subclassed for compilation on a diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index cd983bc03e993caed883916de01d75dffdbc4bab..df73c285971e237b6f5492f8a7c587f23646ec1e 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -58,6 +58,45 @@ bool ValueIsReadOnly(const HloValue& value) { return IsConstantValue(value) || IsEntryParameterValue(value); } +// Data structure describing the action which should be taken on parts of a +// computation buffers, with respect to the adding of special case copies. +struct SpecialCaseCopyPolicy { + // Insert a copy if the same buffer is found at multiple indices within the + // output tuple. + bool copy_root_replicated_buffers = false; + // If true, insert a copy if a buffer coming from a constant or a parameter + // is found wihtin the output tuple. + bool copy_parameters_and_constants = false; +}; + +SpecialCaseCopyPolicy GetSpecialCaseCopyPolicy(const CallGraphNode& node, + HloModule* module, + HloComputation* computation) { + SpecialCaseCopyPolicy policy; + if (computation == module->entry_computation()) { + policy.copy_parameters_and_constants = true; + policy.copy_root_replicated_buffers = true; + } + for (const CallSite& site : node.caller_callsites()) { + // The kWhile instruction does not have an handling here, as the + // AddCopiesForWhile() API takes care of adding its own copies. + if (site.instruction()->opcode() == HloOpcode::kConditional) { + policy.copy_parameters_and_constants = true; + policy.copy_root_replicated_buffers = true; + } + } + return policy; +} + +bool ShouldCopyRootValue(const HloValue& value, + const SpecialCaseCopyPolicy& policy) { + if (policy.copy_parameters_and_constants) { + return IsConstantValue(value) || + value.defining_instruction()->opcode() == HloOpcode::kParameter; + } + return false; +} + // Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in // 'indices_to_copy'. Add control edges from the respective kCopy instructions // in deep copy of 'from' to the respective kCopy instruction in the deep copy @@ -729,7 +768,8 @@ class CopyRemover { // has a different operand (the operand of the elided copy). for (const HloUse* copy_use : copy_value_node->uses) { operand_node->uses.push_back(copy_use); - if (copy_use->instruction->opcode() == HloOpcode::kCopy) { + if (copy_use->instruction->opcode() == HloOpcode::kCopy && + ContainsKey(copy_map_, copy_use->instruction)) { copy_map_.at(copy_use->instruction).src = operand_node; } } @@ -956,7 +996,8 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { } TF_RET_CHECK(node.context() == CallContext::kSequential); - const bool is_entry = computation == module->entry_computation(); + SpecialCaseCopyPolicy policy = + GetSpecialCaseCopyPolicy(node, module, computation); HloInstruction* root = computation->root_instruction(); // Mark nondistinct/ambiguous indices. @@ -969,27 +1010,26 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { for (const HloBuffer* buffer : buffers_at_index) { buffer_seen_before |= !seen.insert(buffer).second; } - if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) { - VLOG(2) << "Index " << index << " of root of computation " + if (buffers_at_index.size() > 1 || + (buffer_seen_before && policy.copy_root_replicated_buffers)) { + VLOG(2) << "Index " << index << " of computation " << computation->name() << " (" << root->name() << ") has ambiguous or non-distinct buffer. Copying."; add_index_to_copy(root, index); } }); - // For entry instructions, mark any parameter or constant values. - if (is_entry) { - for (const auto& pair : - alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { - const ShapeIndex& index = pair.first; - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (ValueIsReadOnly(*value)) { - VLOG(2) << "Root of entry computation (" << root->name() - << ") has constant or entry parameter value at index " - << index << ". Copying."; - add_index_to_copy(root, index); - } + for (const auto& pair : + alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { + const ShapeIndex& index = pair.first; + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (ShouldCopyRootValue(*value, policy)) { + VLOG(2) << "Root of (" << root->name() << ") of computation(" + << computation->name() + << ") has constant or parameter value at index " << index + << ". Copying."; + add_index_to_copy(root, index); } } } @@ -1011,7 +1051,6 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { instruction->parent()->set_root_instruction(deep_copy); } } - return Status::OK(); } @@ -1155,7 +1194,7 @@ bool IsWhileBody(const HloComputation* computation, HloModule* module) { std::unique_ptr call_graph = CallGraph::Build(module); TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(module)); + HloDataflowAnalysis::Run(*module)); bool changed = false; diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 128ee726ea6e4a8b63727fdc9762d865cee1c985..153f062d015e49db11c4c9ae0a2a61e76c020f02 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -1724,8 +1724,58 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { } } +std::unique_ptr MakeBenchmarkWhileBody( + const int num_tuple_inputs) { + auto builder = HloComputation::Builder("benchmark_loop_body"); + const Shape element_shape = ShapeUtil::MakeShape(F32, {}); + std::vector input_shape(num_tuple_inputs, element_shape); + const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); + std::vector gte_nodes(num_tuple_inputs); + for (int i = 0; i < num_tuple_inputs; ++i) { + gte_nodes[i] = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, i)); + } + builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes)); + return builder.Build(); +} + +void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { + tensorflow::testing::StopTiming(); + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + CopyInsertion copy_insertion; + const Shape element_shape = ShapeUtil::MakeShape(F32, {}); + std::vector tuple_params(num_tuple_inputs); + for (int i = 0; i < num_iters; ++i) { + auto builder = HloComputation::Builder("BM_ParallelWhiles"); + HloModule module("BM_ManyElementTuple", VersionedComputationHandle(), + config); + for (int j = 0; j < num_tuple_inputs; ++j) { + tuple_params[j] = builder.AddInstruction( + HloInstruction::CreateParameter(j, element_shape, "")); + } + HloInstruction* init = + builder.AddInstruction(HloInstruction::CreateTuple(tuple_params)); + HloComputation* condition = + module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); + HloComputation* body = + module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs)); + HloInstruction* xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(init->shape(), condition, body, init)); + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(F32, {}), xla_while, 0)); + module.AddEntryComputation(builder.Build()); + tensorflow::testing::StartTiming(); + ASSERT_IS_OK(copy_insertion.Run(&module).status()); + tensorflow::testing::StopTiming(); + } +} + BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); +BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288); TEST_F(CopyInsertionTest, SimpleControlFlowTest) { const string& hlo_string = R"( diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 1a91dd8ff71b93ad9c5d43ca14d3d7d9ccde5bf2..32be0b0c968f2d24f460fc8377c458f2da282112 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -159,13 +159,11 @@ cc_library( deps = [ ":compiler_functor", ":cpu_runtime", - ":cpu_runtime_avx", - ":cpu_runtime_neon", - ":cpu_runtime_sse4_1", ":custom_call_target_registry", ":disassembler", ":external_constant_pool", ":orc_jit_memory_mapper", + ":runtime_fp16", ":runtime_conv2d", ":runtime_fft", ":runtime_fork_join", @@ -185,6 +183,20 @@ cc_library( ] + ORC_JIT_MEMORY_MAPPER_TARGETS, ) +cc_library( + name = "runtime_fp16", + srcs = [ + "runtime_fp16.cc", + ], + hdrs = [ + "runtime_fp16.h", + ], + copts = runtime_copts(), + deps = [ + "//tensorflow/core:framework_lite", + ], +) + cc_library( name = "cpu_executable", srcs = ["cpu_executable.cc"], @@ -408,9 +420,6 @@ cc_library( hdrs = ["compiler_functor.h"], deps = [ ":cpu_runtime", - ":cpu_runtime_avx", - ":cpu_runtime_neon", - ":cpu_runtime_sse4_1", ":disassembler", ":llvm_ir_runtime", "//tensorflow/compiler/xla:statusor", @@ -430,43 +439,6 @@ cc_library( ], ) -cc_library( - name = "cpu_runtime_sse4_1", - srcs = ["cpu_runtime_sse4_1.cc"], - hdrs = ["cpu_runtime_sse4_1.h"], - copts = ["-DEIGEN_AVOID_STL_ARRAY"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], -) - -cc_library( - name = "cpu_runtime_avx", - srcs = ["cpu_runtime_avx.cc"], - hdrs = ["cpu_runtime_avx.h"], - copts = ["-DEIGEN_AVOID_STL_ARRAY"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], -) - -cc_library( - name = "cpu_runtime_neon", - srcs = ["cpu_runtime_neon.cc"], - hdrs = ["cpu_runtime_neon.h"], - # runtime_copts() enables -mfpu=neon - copts = ["-DEIGEN_AVOID_STL_ARRAY"] + runtime_copts(), - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], -) - cc_library( name = "cpu_runtime", srcs = [ diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 2723661712383eeed81f2330c78ad15e2064b191..61b2da7a7dce7f6fba46a23cc8e5462a3899a18c 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -37,9 +37,6 @@ limitations under the License. #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -50,15 +47,6 @@ limitations under the License. namespace xla { namespace cpu { -/* static */ CompilerFunctor::VectorIntrinsics -CompilerFunctor::AllIntrinsics() { - VectorIntrinsics intrinsics; - intrinsics.sse_intrinsics = true; - intrinsics.avx_intrinsics = true; - intrinsics.neon_intrinsics = true; - return intrinsics; -} - /* Create filtered versions of the LLVM Pass Managers to filter out some of the expensive passes. Profiling: @@ -105,8 +93,8 @@ class FilteredPassManager : public llvm::legacy::PassManager { }; } // anonymous namespace -llvm::object::OwningBinary CompilerFunctor:: -operator()(llvm::Module& module) const { +std::unique_ptr CompilerFunctor::operator()( + llvm::Module& module) const { FilteredPassManager module_passes(disable_expensive_passes_); FilteredFunctionPassManager function_passes(&module, disable_expensive_passes_); @@ -169,54 +157,12 @@ operator()(llvm::Module& module) const { codegen_passes.run(module); // Construct ObjectFile from machine code buffer. - std::unique_ptr memory_buffer( + return std::unique_ptr( new llvm::ObjectMemoryBuffer(std::move(stream_buffer))); - llvm::Expected> - object_file_or_error = llvm::object::ObjectFile::createObjectFile( - memory_buffer->getMemBufferRef()); - CHECK(object_file_or_error); - - std::unique_ptr object_file = - std::move(object_file_or_error.get()); - if (VLOG_IS_ON(2)) { - StatusOr disassembly_status = - disassembler_->DisassembleObjectFile(*object_file); - if (disassembly_status.ok()) { - auto result = disassembly_status.ValueOrDie(); - XLA_VLOG_LINES(2, result.text); - VLOG(2) << "compiled code size: " << result.code_size_bytes << " bytes"; - } - } - - return llvm::object::OwningBinary( - std::move(object_file), std::move(memory_buffer)); } -namespace { -// Returns the set of vectorized library functions supported for the target. -std::vector VectorFunctionsForTargetLibraryInfoImpl( - llvm::Triple::ArchType arch, llvm::StringRef feature_string, - CompilerFunctor::VectorIntrinsics const& available_intrinsics) { - std::vector vector_functions; - - const llvm::VecDesc four_wide_vector_functions_neon[] = { - {"logf", runtime::kLogV4F32NEONSymbolName, 4}, - {"llvm.log.f32", runtime::kLogV4F32NEONSymbolName, 4}, - }; - - const llvm::VecDesc four_wide_vector_functions_sse[] = { - {"logf", runtime::kLogV4F32SSESymbolName, 4}, - {"llvm.log.f32", runtime::kLogV4F32SSESymbolName, 4}, - }; - - const llvm::VecDesc eight_wide_vector_functions_avx[] = { - {"logf", runtime::kLogV8F32AVXSymbolName, 8}, - {"llvm.log.f32", runtime::kLogV8F32AVXSymbolName, 8}, - }; - - // These functions are generated by XLA as LLVM IR, so they're always - // available. - const llvm::VecDesc ir_vector_functions[] = { +static std::vector VectorFunctionsForTargetLibraryInfoImpl() { + std::vector result = { {"tanhf", runtime::kTanhV4F32SymbolName, 4}, {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4}, @@ -228,50 +174,15 @@ std::vector VectorFunctionsForTargetLibraryInfoImpl( {"expf", runtime::kExpV8F32SymbolName, 8}, {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8}, - }; - llvm::SmallVector features; - feature_string.split(features, ',', -1, /*KeepEmpty=*/false); - auto has_feature = [&features](const llvm::StringRef feature) { - return std::find(features.begin(), features.end(), feature) != - features.end(); - }; + {"logf", runtime::kLogV4F32SymbolName, 4}, + {"llvm.log.f32", runtime::kLogV4F32SymbolName, 4}, - switch (arch) { - case llvm::Triple::x86: - case llvm::Triple::x86_64: { - if (has_feature("+sse4.1") && available_intrinsics.sse_intrinsics) { - vector_functions.insert(vector_functions.end(), - std::begin(four_wide_vector_functions_sse), - std::end(four_wide_vector_functions_sse)); - } - if (has_feature("+avx") && available_intrinsics.avx_intrinsics) { - vector_functions.insert(vector_functions.end(), - std::begin(eight_wide_vector_functions_avx), - std::end(eight_wide_vector_functions_avx)); - } - break; - } - case llvm::Triple::arm: - case llvm::Triple::aarch64: { - if (has_feature("+neon") && available_intrinsics.neon_intrinsics) { - vector_functions.insert(vector_functions.end(), - std::begin(four_wide_vector_functions_neon), - std::end(four_wide_vector_functions_neon)); - } - break; - } - default: - break; - } - - vector_functions.insert(vector_functions.end(), - std::begin(ir_vector_functions), - std::end(ir_vector_functions)); - - return vector_functions; + {"logf", runtime::kLogV8F32SymbolName, 8}, + {"llvm.log.f32", runtime::kLogV8F32SymbolName, 8}, + }; + return result; } -} // namespace void CompilerFunctor::AddTargetInfoPasses( llvm::legacy::PassManagerBase* passes) const { @@ -279,9 +190,7 @@ void CompilerFunctor::AddTargetInfoPasses( auto target_library_info_impl = MakeUnique(target_triple); target_library_info_impl->addVectorizableFunctions( - VectorFunctionsForTargetLibraryInfoImpl( - target_triple.getArch(), target_machine_->getTargetFeatureString(), - available_intrinsics_)); + VectorFunctionsForTargetLibraryInfoImpl()); passes->add( new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl)); passes->add(createTargetTransformInfoWrapperPass( diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index 8cdd049e7b773bdc455db627ff1749997d621ee4..c38b896c5019b48fd2a16a51abd59e12ebdb29eb 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -31,21 +31,10 @@ namespace cpu { // Orc JIT compile layer. class CompilerFunctor { public: - // Describes the set of vector intrinsics available to the generated code. - struct VectorIntrinsics { - bool sse_intrinsics; - bool avx_intrinsics; - bool neon_intrinsics; - }; - - // Returns a VectorIntrinsics where all intrinsics are available. - static VectorIntrinsics AllIntrinsics(); - explicit CompilerFunctor( llvm::TargetMachine* target_machine, const Disassembler* disassembler, int opt_level, bool optimize_for_size, bool enable_fast_math, bool disable_expensive_passes, - const VectorIntrinsics& available_intrinsics, LLVMCompiler::ModuleHook pre_optimization_hook = nullptr, LLVMCompiler::ModuleHook post_optimization_hook = nullptr) : target_machine_(target_machine), @@ -54,12 +43,11 @@ class CompilerFunctor { optimize_for_size_(optimize_for_size), enable_fast_math_(enable_fast_math), disable_expensive_passes_(disable_expensive_passes), - available_intrinsics_(available_intrinsics), pre_optimization_hook_(pre_optimization_hook), post_optimization_hook_(post_optimization_hook) {} // Compile a Module to an ObjectFile. - llvm::object::OwningBinary operator()( + std::unique_ptr operator()( llvm::Module& module) const; // NOLINT private: @@ -78,7 +66,6 @@ class CompilerFunctor { const bool optimize_for_size_; const bool enable_fast_math_; const bool disable_expensive_passes_; - const VectorIntrinsics available_intrinsics_; LLVMCompiler::ModuleHook pre_optimization_hook_; LLVMCompiler::ModuleHook post_optimization_hook_; }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index d13a97bcc9a84afb22556389b4cdcd985f58d445..387806e24aad0d5f28cb104507ef6cc136ffd779 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -888,13 +888,11 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, options::OptimizeForSizeRequested(module->config()), module->config().debug_options().xla_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), - CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook, - post_optimization_ir_dump_hook); - llvm::object::OwningBinary object_file = + pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); + std::unique_ptr object_file = compiler_functor(llvm_module); - llvm::StringRef object_file_data_ref = object_file.getBinary()->getData(); - ObjectFileData object_file_data(object_file_data_ref.begin(), - object_file_data_ref.end()); + ObjectFileData object_file_data(object_file->getBufferStart(), + object_file->getBufferEnd()); BufferSizes buffer_sizes; for (const BufferAllocation& allocation : assignment->Allocations()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 802d0a6fb46890b31d14b1fbf3b2e7d6520caccc..c053703c3524a47ee1de9681c1b986edbf109430 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -63,7 +63,7 @@ CpuExecutable::CpuExecutable( assignment_(std::move(assignment)) { // Resolve symbols in the constructor rather than at execution time to avoid // races because FindSymbol is not thread safe. - llvm::JITSymbol sym = jit_->FindSymbol(entry_function_name); + llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry_function_name); // We expect to find the symbol provided with entry_function_name; otherwise // this is an internal error. CHECK(sym) << "Symbol " << entry_function_name << " not found."; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 1ef45dbec39a0880ebb123ba3fcd1fd6c89eb39a..40ace963270e8cead47cc731cc326351178dff7d 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -35,6 +35,8 @@ extern const char* const kEigenMatMulF32SymbolName = "__xla_cpu_runtime_EigenMatMulF32"; extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; +extern const char* const kEigenConvF16SymbolName = + "__xla_cpu_runtime_EigenConvF16"; extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; @@ -42,6 +44,8 @@ extern const char* const kEigenSingleThreadedMatMulF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; extern const char* const kEigenSingleThreadedMatMulF64SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF64"; +extern const char* const kEigenSingleThreadedConvF16SymbolName = + "__xla_cpu_runtime_EigenSingleThreadedConvF16"; extern const char* const kEigenSingleThreadedConvF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedConvF32"; extern const char* const kAcquireInfeedBufferForDequeueSymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 3e1f08071119c938619d02777513e5b834077118..2141dfe1cedd6f9674acc348152574b4fd30895b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -43,10 +43,12 @@ namespace runtime { // because it is a symbol in the cpu_runtime library. extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; +extern const char* const kEigenConvF16SymbolName; extern const char* const kEigenConvF32SymbolName; extern const char* const kEigenFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; +extern const char* const kEigenSingleThreadedConvF16SymbolName; extern const char* const kEigenSingleThreadedConvF32SymbolName; extern const char* const kAcquireInfeedBufferForDequeueSymbolName; extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h deleted file mode 100644 index f473c689f297d7c7c3df18b128b99caee0239ea0..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This header declares functions which may be called by the generated code on -// the CPU. Calls to these functions must be resolved explicitly in the JIT in -// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context -// which is used to cache expensive state and resources utilized by the -// aforementioned functions. - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ - -#include "tensorflow/core/platform/macros.h" - -#if defined(__AVX__) -#include -#define TF_XLA_HAS_AVX -#endif - -namespace xla { -namespace cpu { -namespace runtime { - -extern const char *const kLogV8F32AVXSymbolName; - -#ifdef TF_XLA_HAS_AVX -typedef __m256 V8F32AVX; -#endif -} // namespace runtime -} // namespace cpu -} // namespace xla - -extern "C" { - -#ifdef TF_XLA_HAS_AVX -// The following functions are vectorized versions of a selection of libm -// library functions. -// References to these functions are created by the LLVM vectorizer. -xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX( - xla::cpu::runtime::V8F32AVX x); - -xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX( - xla::cpu::runtime::V8F32AVX x); -#endif -} - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc deleted file mode 100644 index 8099b722f10ecb83f7cf6c58ba2abb783478b97f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/Eigen/Core" - -#ifdef TF_XLA_HAS_NEON - -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON( - xla::cpu::runtime::V4F32NEON x) { - return Eigen::internal::pexp(x); -} - -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON( - xla::cpu::runtime::V4F32NEON x) { - Eigen::internal::Packet4f p = x; - return Eigen::internal::plog(p); -} - -#endif // TF_XLA_HAS_NEON - -namespace xla { -namespace cpu { -namespace runtime { - -const char *const kExpV4F32NEONSymbolName = "__xla_cpu_runtime_ExpV4F32NEON"; -const char *const kLogV4F32NEONSymbolName = "__xla_cpu_runtime_LogV4F32NEON"; - -} // namespace runtime -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h deleted file mode 100644 index 2f5d1a872aaf3868d6d27f88a4f05c778d45660f..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ - -// This header declares functions which may be called by the generated code on -// the CPU. Calls to these functions must be resolved explicitly in the JIT in -// xla::cpu::SimpleResolver. - -#include "tensorflow/core/platform/macros.h" - -#ifdef __ARM_NEON__ -// For the other runtimes (AVX, SSE4.1) we define the vector type directly using -// __attribute__((__vector_size__(*))). Unfortunately, the typedef for the ARM -// NEON SIMD types is not portable, so the type has to come from -#include -#define TF_XLA_HAS_NEON -#endif // __ARM_NEON__ - -namespace xla { -namespace cpu { -namespace runtime { - -extern const char *const kExpV4F32NEONSymbolName; -extern const char *const kLogV4F32NEONSymbolName; - -#ifdef TF_XLA_HAS_NEON -typedef float32x4_t V4F32NEON; -#endif // TF_XLA_HAS_NEON - -} // namespace runtime -} // namespace cpu -} // namespace xla - -extern "C" { - -#ifdef TF_XLA_HAS_NEON -// The following functions are vectorized versions of a selection of libm -// library functions. -// References to these functions are created by the LLVM vectorizer. -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON( - xla::cpu::runtime::V4F32NEON x); - -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON( - xla::cpu::runtime::V4F32NEON x); -#endif // TF_XLA_HAS_NEON -} - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h deleted file mode 100644 index 3b3d18112aef5c091b8b2eb67775c79c450ebce7..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h +++ /dev/null @@ -1,59 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This header declares functions which may be called by the generated code on -// the CPU. Calls to these functions must be resolved explicitly in the JIT in -// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context -// which is used to cache expensive state and resources utilized by the -// aforementioned functions. - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ - -#include "tensorflow/core/platform/macros.h" - -// MSVC does not have __SSE4_1__ macro. Eigen enables EIGEN_VECTORIZE_SSE4_1 -// when __AVX__ is defined, we should do the same. -#if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__)) -#include -#define TF_XLA_HAS_SSE4_1 -#endif - -namespace xla { -namespace cpu { -namespace runtime { - -extern const char *const kLogV4F32SSESymbolName; - -#ifdef TF_XLA_HAS_SSE4_1 -typedef __m128 V4F32SSE; -#endif - -} // namespace runtime -} // namespace cpu -} // namespace xla - -extern "C" { - -#ifdef TF_XLA_HAS_SSE4_1 -// The following functions are vectorized versions of a selection of libm -// library functions. -// References to these functions are created by the LLVM vectorizer. -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( - xla::cpu::runtime::V4F32SSE x); -#endif -} - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index c9fc586b9a4c06eb9e1f111d8f9bd2f717990aab..cfe7c9c3af0be109ac8a86753e880e2bcbceba41 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -549,7 +549,7 @@ DotOpEmitter::DotOpEmitter( const HloModuleConfig& hlo_module_config, const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); - TF_RET_CHECK(F32 == type || F64 == type || C64 == type); + TF_RET_CHECK(F16 == type || F32 == type || F64 == type || C64 == type); DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, lhs_array, rhs_array, addend_array, executable_run_options_value, ir_builder, diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index d9eeb1c3bdc2a8058992de0e13045a240bf56b8d..4dffaee87f6b33933b58c8c58478eec918569197 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -801,7 +801,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { auto rhs = dot->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32, F64, C64})); + /*supported_types=*/{F16, F32, F64, C64})); const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); if (dnums.lhs_batch_dimensions_size() > 0 || dnums.rhs_batch_dimensions_size() > 0) { @@ -849,7 +849,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { const auto& window = convolution->window(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32, C64})); + /*supported_types=*/{F16, F32, C64})); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); @@ -928,25 +928,30 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { int64 rhs_col_dilation = one_dim_convolution ? 1 : window.dimensions(1).window_dilation(); - // Args have been computed, make the call. - llvm::Type* float_ptr_type = ir_builder_.getFloatTy()->getPointerTo(); + PrimitiveType primitive_type = lhs->shape().element_type(); + llvm::Type* ir_ptr_type = primitive_type == F16 + ? ir_builder_.getHalfTy()->getPointerTo() + : ir_builder_.getFloatTy()->getPointerTo(); llvm::Type* int64_type = ir_builder_.getInt64Ty(); llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); llvm::FunctionType* conv_type = llvm::FunctionType::get( ir_builder_.getVoidTy(), - {int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type, - int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type, - int64_type, int64_type, int64_type, int64_type}, + {int8_ptr_type, ir_ptr_type, ir_ptr_type, ir_ptr_type, int64_type, + int64_type, int64_type, int64_type, int64_type, int64_type, + int64_type, int64_type, int64_type, int64_type, int64_type, + int64_type, int64_type, int64_type, int64_type, int64_type, + int64_type, int64_type, int64_type, int64_type}, /*isVarArg=*/false); bool multi_threaded_eigen = hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen(); const char* fn_name = - (multi_threaded_eigen - ? runtime::kEigenConvF32SymbolName - : runtime::kEigenSingleThreadedConvF32SymbolName); + primitive_type == F16 + ? (multi_threaded_eigen + ? runtime::kEigenConvF16SymbolName + : runtime::kEigenSingleThreadedConvF16SymbolName) + : (multi_threaded_eigen + ? runtime::kEigenConvF32SymbolName + : runtime::kEigenSingleThreadedConvF32SymbolName); llvm::Function* conv_func = llvm::cast( module_->getOrInsertFunction(fn_name, conv_type)); conv_func->setCallingConv(llvm::CallingConv::C); @@ -956,9 +961,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { conv_func, { GetExecutableRunOptionsArgument(), ir_builder_.CreateBitCast( - GetEmittedValueFor(convolution), float_ptr_type), - ir_builder_.CreateBitCast(lhs_address, float_ptr_type), - ir_builder_.CreateBitCast(rhs_address, float_ptr_type), + GetEmittedValueFor(convolution), ir_ptr_type), + ir_builder_.CreateBitCast(lhs_address, ir_ptr_type), + ir_builder_.CreateBitCast(rhs_address, ir_ptr_type), ir_builder_.getInt64(input_batch), ir_builder_.getInt64(input_rows), ir_builder_.getInt64(input_cols), diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 38fcd278e9298f8458fd37e81458a2b9a150bb0e..2e5cc96098241415b82f225afc81981f3e1069e0 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -21,6 +21,7 @@ limitations under the License. #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -31,6 +32,8 @@ const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32"; const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32"; const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32"; const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32"; +const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX"; +const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX"; namespace { llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, @@ -60,7 +63,8 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, CHECK_EQ(input->getType(), vsl.vector_type()); // This implements the same rational interpolant as implemented in Eigen3. - llvm::Value* input_clamped = vsl.Clamp(input, /*low=*/-9.0, /*high=*/9.0); + llvm::Value* input_clamped = + vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0)); std::array numerator_coeffs{ -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, @@ -72,16 +76,18 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, 4.89352518554385e-03f}; llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped); - llvm::Value* numerator = vsl.SplatFloat(numerator_coeffs[0]); + llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0])); for (int i = 1; i < numerator_coeffs.size(); i++) { - numerator = vsl.MulAdd(input_squared, numerator, numerator_coeffs[i]); + numerator = + vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i])); } numerator = vsl.Mul(input_clamped, numerator); - llvm::Value* denominator = vsl.SplatFloat(denominator_coeffs[0]); + llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0])); for (int i = 1; i < denominator_coeffs.size(); i++) { - denominator = vsl.MulAdd(input_squared, denominator, denominator_coeffs[i]); + denominator = vsl.MulAdd(input_squared, denominator, + GetIeeeF32(denominator_coeffs[i])); } llvm::Value* result = vsl.Div(numerator, denominator); @@ -116,24 +122,27 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, // This implements the same polynomial approximation as implemented in Eigen3. - const double exp_hi = 88.3762626647950; - const double exp_lo = -88.3762626647949; + const llvm::APFloat half = GetIeeeF32(0.5); + const llvm::APFloat one = GetIeeeF32(1.0); + + const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950); + const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949); - const double cephes_LOG2EF = 1.44269504088896341; - const double cephes_exp_C1 = 0.693359375; - const double cephes_exp_C2 = -2.12194440e-4; + const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341); + const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375); + const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4); - const double cephes_exp_p0 = 1.9875691500E-4; - const double cephes_exp_p1 = 1.3981999507E-3; - const double cephes_exp_p2 = 8.3334519073E-3; - const double cephes_exp_p3 = 4.1665795894E-2; - const double cephes_exp_p4 = 1.6666665459E-1; - const double cephes_exp_p5 = 5.0000001201E-1; + const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4); + const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3); + const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3); + const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2); + const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1); + const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1); llvm::Value* input = &*vector_exp_function->arg_begin(); llvm::Value* input_clamped = vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi); - llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, 0.5)); + llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half)); llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx); llvm::Value* z = vsl.Mul(cephes_exp_C2, fx); llvm::Value* x = vsl.Sub(input_clamped, tmp); @@ -146,7 +155,7 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, y = vsl.MulAdd(y, x, cephes_exp_p4); y = vsl.MulAdd(y, x, cephes_exp_p5); y = vsl.MulAdd(y, z, x); - y = vsl.Add(1.0, y); + y = vsl.Add(one, y); // VectorSupportLibrary (intentionally) can't juggle more than one type at a // time so drop down to IRBuilder for this bit. @@ -167,9 +176,129 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, ir_builder.CreateRet(result); - CHECK(!llvm::verifyFunction(*vector_exp_function)); + DCHECK(!llvm::verifyFunction(*vector_exp_function)); return vector_exp_function; } + +llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, + llvm::StringRef function_name, + int vector_width, + bool enable_fast_math) { + llvm::Function* vector_log_function = module->getFunction(function_name); + if (vector_log_function == nullptr) { + // If the function declaration is not present in the module, there can't be + // any calls to resolve. Don't emit the function in this case. + return nullptr; + } + + llvm::LLVMContext* context = &module->getContext(); + + llvm::BasicBlock* vector_log_body = + llvm::BasicBlock::Create(*context, "body", vector_log_function); + + llvm::IRBuilder<> ir_builder(vector_log_body); + llvm::FastMathFlags fast_math_flags; + fast_math_flags.setFast(); + ir_builder.setFastMathFlags(fast_math_flags); + + llvm::Value* input = &*vector_log_function->arg_begin(); + VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32"); + + const llvm::APFloat half = GetIeeeF32(0.5); + const llvm::APFloat one = GetIeeeF32(1.0); + + // This implements the same polynomial approximation as implemented in Eigen3. + // Returns NaN for x < 0, -INF for x = 0 + const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524); + const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2); + const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1); + const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1); + const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1); + const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1); + const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1); + const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1); + const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1); + const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1); + const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4); + const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375); + + // The smallest non denormalized float number. + const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000); + const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000); + const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000); + + // invalid_mask is set if x is negative or NaN (and therefore output + // must be NaN). + llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector()); + llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); + + // Cut off denormalized stuff. + input = vsl.Max(min_norm_pos, input); + + // VectorSupportLibrary (intentionally) can't juggle more than one type at a + // time so drop down to IRBuilder for this bit. + llvm::Value* vector_constant_0x7f = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); + llvm::Value* vector_constant_23 = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); + llvm::Type* i32_vector_type = + llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); + + llvm::Value* emm0 = ir_builder.CreateLShr( + ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23); + + // Keep only the fractional part. + input = vsl.FloatAnd(input, inv_mant_mask); + input = vsl.FloatOr(input, half); + + emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f); + llvm::Value* e = + vsl.Add(one, ir_builder.CreateSIToFP(emm0, vsl.vector_type())); + + // part2: + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF); + llvm::Value* tmp = vsl.FloatAnd(input, mask); + input = vsl.Sub(input, one); + e = vsl.Sub(e, vsl.FloatAnd(mask, one)); + input = vsl.Add(input, tmp); + + llvm::Value* x2 = vsl.Mul(input, input); + llvm::Value* x3 = vsl.Mul(x2, input); + + llvm::Value *y, *y1, *y2; + y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1); + y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4); + y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7); + y = vsl.MulAdd(y, input, cephes_log_p2); + y1 = vsl.MulAdd(y1, input, cephes_log_p5); + y2 = vsl.MulAdd(y2, input, cephes_log_p8); + y = vsl.MulAdd(y, x3, y1); + y = vsl.MulAdd(y, x3, y2); + y = vsl.Mul(y, x3); + + y1 = vsl.Mul(cephes_log_q1, e); + tmp = vsl.Mul(half, x2); + y = vsl.Add(y, y1); + input = vsl.Sub(input, tmp); + y2 = vsl.Mul(cephes_log_q2, e); + input = vsl.Add(input, y); + input = vsl.Add(input, y2); + + // Negative arg will be NAN, 0 will be -INF. + llvm::Value* or_lhs = + vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask)); + llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf); + llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs); + + ir_builder.CreateRet(result); + + DCHECK(!llvm::verifyFunction(*vector_log_function)); + return vector_log_function; +} } // namespace void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { @@ -187,11 +316,21 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName, /*vector_width=*/8, enable_fast_math); + auto* log_v4f32 = + EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName, + /*vector_width=*/4, enable_fast_math); + auto* log_v8f32 = + EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName, + /*vector_width=*/8, enable_fast_math); + // Gather all the call sites, force inline them and then delete the vector // function bodies. + // + // TODO(b/73081976): Should we avoid inlining these intrinsics in some cases? std::vector calls_to_inline; - for (auto* function : {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32}) { + for (auto* function : + {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { if (function != nullptr) { for (auto* user : function->users()) { calls_to_inline.push_back(llvm::cast(user)); @@ -204,7 +343,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); } - for (auto* function : {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32}) { + for (auto* function : + {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { if (function != nullptr) { function->eraseFromParent(); } diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h index 90050c44594f28614039bd85b20c4ecc0945907a..5553972677512617ccb6ac4f57a4d33400b664e3 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h @@ -27,6 +27,8 @@ extern const char* const kTanhV4F32SymbolName; extern const char* const kTanhV8F32SymbolName; extern const char* const kExpV4F32SymbolName; extern const char* const kExpV8F32SymbolName; +extern const char* const kLogV4F32SymbolName; +extern const char* const kLogV8F32SymbolName; // The following CPU runtime functions have LLVM-IR only implementations: // diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index cd997f07890cdc1d9a546ede58cc1d992b6416ae..07a9f0efcb64db4b2ff0c6518d4b48eee9a505e0 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -394,7 +394,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( for (auto& entry : *function_names_) { tensorflow::mutex_lock lock(jit_mutex_); HloInstruction* instruction = entry.first; - llvm::JITSymbol sym = jit_->FindSymbol(entry.second); + llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry.second); TF_RET_CHECK(sym); InsertOrDie( &functions, instruction, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc index c2f64eb27a554d17ebe2a94dba334fe378bd7254..3905e7ff2a14d25813e345399e692f9e0f4bd0af 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.cc @@ -34,7 +34,26 @@ TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF32( int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { const xla::ExecutableRunOptions* run_options = static_cast(run_options_ptr); - tensorflow::xla::EigenConvF32Impl( + tensorflow::xla::EigenConvImpl( + *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, + input_rows, input_cols, input_channels, kernel_rows, kernel_cols, + kernel_channels, kernel_filters, output_rows, output_cols, row_stride, + col_stride, padding_top, padding_bottom, padding_left, padding_right, + lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation); +} + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenConvF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols, + int64 input_channels, int64 kernel_rows, int64 kernel_cols, + int64 kernel_channels, int64 kernel_filters, int64 output_rows, + int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top, + int64 padding_bottom, int64 padding_left, int64 padding_right, + int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation, + int64 rhs_col_dilation) { + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + tensorflow::xla::EigenConvImpl( *run_options->intra_op_thread_pool(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h index 05ae094691fd9a7ca83b902145c0750fafdc529a..39e20ed45639040110b99ddb52eb6f6dab26dfaa 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d.h @@ -34,6 +34,20 @@ extern void __xla_cpu_runtime_EigenConvF32( tensorflow::int64 lhs_col_dilation, tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation); +extern void __xla_cpu_runtime_EigenConvF16( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, + tensorflow::int64 input_batch, tensorflow::int64 input_rows, + tensorflow::int64 input_cols, tensorflow::int64 input_channels, + tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols, + tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters, + tensorflow::int64 output_rows, tensorflow::int64 output_cols, + tensorflow::int64 row_stride, tensorflow::int64 col_stride, + tensorflow::int64 padding_top, tensorflow::int64 padding_bottom, + tensorflow::int64 padding_left, tensorflow::int64 padding_right, + tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation, + tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation); + } // extern "C" #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_CONV2D_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h index 02f45fee0f1b8cd1125ec6a97f01e0028137bb69..85af63bb032ce33bdd188d6e5bcd78a726d5d9fa 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_conv2d_impl.h @@ -24,26 +24,27 @@ limitations under the License. namespace tensorflow { namespace xla { -template -void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs, - float* rhs, int64 input_batch, int64 input_rows, - int64 input_cols, int64 input_channels, int64 kernel_rows, - int64 kernel_cols, int64 kernel_channels, - int64 kernel_filters, int64 output_rows, - int64 output_cols, int64 row_stride, int64 col_stride, - int64 padding_top, int64 padding_bottom, - int64 padding_left, int64 padding_right, - int64 lhs_row_dilation, int64 lhs_col_dilation, - int64 rhs_row_dilation, int64 rhs_col_dilation) { - const Eigen::TensorMap, +template +void EigenConvImpl(const EigenDevice& device, ScalarType* out, ScalarType* lhs, + ScalarType* rhs, int64 input_batch, int64 input_rows, + int64 input_cols, int64 input_channels, int64 kernel_rows, + int64 kernel_cols, int64 kernel_channels, + int64 kernel_filters, int64 output_rows, int64 output_cols, + int64 row_stride, int64 col_stride, int64 padding_top, + int64 padding_bottom, int64 padding_left, + int64 padding_right, int64 lhs_row_dilation, + int64 lhs_col_dilation, int64 rhs_row_dilation, + int64 rhs_col_dilation) { + const Eigen::TensorMap, Eigen::Aligned> input(lhs, input_batch, input_rows, input_cols, input_channels); - const Eigen::TensorMap, + const Eigen::TensorMap, Eigen::Aligned> kernel(rhs, kernel_rows, kernel_cols, kernel_channels, kernel_filters); - Eigen::TensorMap, Eigen::Aligned> + Eigen::TensorMap, + Eigen::Aligned> output(out, input_batch, output_rows, output_cols, kernel_filters); Eigen::array, 1> contract_dims; @@ -75,7 +76,7 @@ void EigenConvF32Impl(const EigenDevice& device, float* out, float* lhs, row_stride, rhs_col_dilation, rhs_row_dilation, lhs_col_dilation, lhs_row_dilation, padding_left, padding_right, padding_top, - padding_bottom, 0.0f) + padding_bottom, static_cast(0.0f)) .reshape(pre_contract_dims) .contract(kernel.reshape(kernel_dims), contract_dims) .reshape(post_contract_dims); diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fp16.cc b/tensorflow/compiler/xla/service/cpu/runtime_fp16.cc new file mode 100644 index 0000000000000000000000000000000000000000..af0275c8bd00c82220fbe116eb90d2692393713b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fp16.cc @@ -0,0 +1,133 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" +#include "tensorflow/core/platform/macros.h" + +namespace { +using tensorflow::uint16; +using tensorflow::uint32; + +// Helper class that lets us access the underlying bit representation +// of a float without breaking C++ strict aliasing. +class AliasedFloatInt { + public: + static_assert(sizeof(float) == sizeof(uint32), ""); + + static AliasedFloatInt FromFloat(float f) { + AliasedFloatInt value; + value.set_float(f); + return value; + } + + static AliasedFloatInt FromUInt(uint32 u) { + AliasedFloatInt value; + value.set_uint(u); + return value; + } + + void set_float(float f) { memcpy(&value_, &f, sizeof(f)); } + float as_float() const { + float f; + memcpy(&f, &value_, sizeof(f)); + return f; + } + + void set_uint(uint32 u) { value_ = u; } + uint32 as_uint() const { return value_; } + + private: + uint32 value_; +}; +} // namespace + +// __gnu_f2h_ieee and __gnu_h2f_ieee are marked as weak symbols so if XLA is +// built with compiler-rt (that also defines these symbols) we don't get a +// duplicate definition linker error. Making these symbols weak also ensures +// that the compiler-rt definitions "win", but that isn't essential. + +// Algorithm copied from Eigen. +uint16 TF_ATTRIBUTE_WEAK __gnu_f2h_ieee(float float_value) { + AliasedFloatInt f = AliasedFloatInt::FromFloat(float_value); + + const AliasedFloatInt f32infty = AliasedFloatInt::FromUInt(255 << 23); + const AliasedFloatInt f16max = AliasedFloatInt::FromUInt((127 + 16) << 23); + const AliasedFloatInt denorm_magic = + AliasedFloatInt::FromUInt(((127 - 15) + (23 - 10) + 1) << 23); + unsigned int sign_mask = 0x80000000u; + uint32 o = static_cast(0x0u); + + unsigned int sign = f.as_uint() & sign_mask; + f.set_uint(f.as_uint() ^ sign); + + // NOTE all the integer compares in this function can be safely + // compiled into signed compares since all operands are below + // 0x80000000. Important if you want fast straight SSE2 code + // (since there's no unsigned PCMPGTD). + + if (f.as_uint() >= + f16max.as_uint()) { // result is Inf or NaN (all exponent bits set) + o = (f.as_uint() > f32infty.as_uint()) ? 0x7e00 + : 0x7c00; // NaN->qNaN and Inf->Inf + } else { // (De)normalized number or zero + if (f.as_uint() < (113 << 23)) { // resulting FP16 is subnormal or zero + // use a magic value to align our 10 mantissa bits at the bottom of + // the float. as long as FP addition is round-to-nearest-even this + // just works. + f.set_float(f.as_float() + denorm_magic.as_float()); + + // and one integer subtract of the bias later, we have our final float! + o = static_cast(f.as_uint() - denorm_magic.as_uint()); + } else { + unsigned int mant_odd = + (f.as_uint() >> 13) & 1; // resulting mantissa is odd + + // update exponent, rounding bias part 1 + f.set_uint(f.as_uint() + (static_cast(15 - 127) << 23) + + 0xfff); + // rounding bias part 2 + f.set_uint(f.as_uint() + mant_odd); + // take the bits! + o = static_cast(f.as_uint() >> 13); + } + } + + o |= static_cast(sign >> 16); + return o; +} + +// Algorithm copied from Eigen. +float TF_ATTRIBUTE_WEAK __gnu_h2f_ieee(uint16 h) { + const AliasedFloatInt magic = AliasedFloatInt::FromUInt(113 << 23); + const unsigned int shifted_exp = 0x7c00 << 13; // exponent mask after shift + AliasedFloatInt o; + + o.set_uint((h & 0x7fff) << 13); // exponent/mantissa bits + unsigned int exp = shifted_exp & o.as_uint(); // just the exponent + o.set_uint(o.as_uint() + ((127 - 15) << 23)); // exponent adjust + + // handle exponent special cases + if (exp == shifted_exp) { // Inf/NaN? + o.set_uint(o.as_uint() + ((128 - 16) << 23)); // extra exp adjust + } else if (exp == 0) { // Zero/Denormal? + o.set_uint(o.as_uint() + (1 << 23)); // extra exp adjust + o.set_float(o.as_float() - magic.as_float()); // renormalize + } + + o.set_uint(o.as_uint() | (h & 0x8000) << 16); // sign bit + return o.as_float(); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fp16.h b/tensorflow/compiler/xla/service/cpu/runtime_fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..01d92d031904af99884c2583a8c7b5086b289d44 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fp16.h @@ -0,0 +1,27 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FP16_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FP16_H_ + +#include "tensorflow/core/platform/types.h" + +// Converts an F32 value to a F16. +extern "C" tensorflow::uint16 __gnu_f2h_ieee(float); + +// Converts an F16 value to a F32. +extern "C" float __gnu_h2f_ieee(tensorflow::uint16); + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FP16_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc index d0b0e11ac0f9fd06e384c2bb5e6296edd0825f5c..5afccc6a86e2df468e3e3e874cf0f4d4e1342a88 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.cc @@ -21,6 +21,24 @@ limitations under the License. using tensorflow::int64; +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void +__xla_cpu_runtime_EigenSingleThreadedConvF16( + const void* run_options_ptr, Eigen::half* out, Eigen::half* lhs, + Eigen::half* rhs, int64 input_batch, int64 input_rows, int64 input_cols, + int64 input_channels, int64 kernel_rows, int64 kernel_cols, + int64 kernel_channels, int64 kernel_filters, int64 output_rows, + int64 output_cols, int64 row_stride, int64 col_stride, int64 padding_top, + int64 padding_bottom, int64 padding_left, int64 padding_right, + int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation, + int64 rhs_col_dilation) { + tensorflow::xla::EigenConvImpl( + Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, + input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, + kernel_filters, output_rows, output_cols, row_stride, col_stride, + padding_top, padding_bottom, padding_left, padding_right, + lhs_row_dilation, lhs_col_dilation, rhs_row_dilation, rhs_col_dilation); +} + TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenSingleThreadedConvF32( const void* run_options_ptr, float* out, float* lhs, float* rhs, @@ -30,7 +48,7 @@ __xla_cpu_runtime_EigenSingleThreadedConvF32( int64 row_stride, int64 col_stride, int64 padding_top, int64 padding_bottom, int64 padding_left, int64 padding_right, int64 lhs_row_dilation, int64 lhs_col_dilation, int64 rhs_row_dilation, int64 rhs_col_dilation) { - tensorflow::xla::EigenConvF32Impl( + tensorflow::xla::EigenConvImpl( Eigen::DefaultDevice(), out, lhs, rhs, input_batch, input_rows, input_cols, input_channels, kernel_rows, kernel_cols, kernel_channels, kernel_filters, output_rows, output_cols, row_stride, col_stride, diff --git a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h index 8ae1a42149bde26ca2f510ad47e76ae47f34a977..f216bd0152aa93b8753d881938c63a9cabea899b 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h @@ -20,6 +20,20 @@ limitations under the License. extern "C" { +extern void __xla_cpu_runtime_EigenSingleThreadedConvF16( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, + Eigen::half* out, Eigen::half* lhs, Eigen::half* rhs, + tensorflow::int64 input_batch, tensorflow::int64 input_rows, + tensorflow::int64 input_cols, tensorflow::int64 input_channels, + tensorflow::int64 kernel_rows, tensorflow::int64 kernel_cols, + tensorflow::int64 kernel_channels, tensorflow::int64 kernel_filters, + tensorflow::int64 output_rows, tensorflow::int64 output_cols, + tensorflow::int64 row_stride, tensorflow::int64 col_stride, + tensorflow::int64 padding_top, tensorflow::int64 padding_bottom, + tensorflow::int64 padding_left, tensorflow::int64 padding_right, + tensorflow::int64 lhs_row_dilation, tensorflow::int64 lhs_col_dilation, + tensorflow::int64 rhs_row_dilation, tensorflow::int64 rhs_col_dilation); + extern void __xla_cpu_runtime_EigenSingleThreadedConvF32( const void* /* xla::ExecutableRunOptions* */ run_options_ptr, float* out, float* lhs, float* rhs, tensorflow::int64 input_batch, diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index 2f4468cca7bc202b6bdcd8870311194f60c35f96..e8a375d63791cd9a94f77af4ef5e74d2cb7e4361 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -28,14 +28,12 @@ limitations under the License. #include "llvm/Support/Host.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fp16.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -47,36 +45,6 @@ namespace xla { namespace cpu { namespace { -// A simple SymbolResolver that delegates to the host dynamic linker. -class SimpleResolver : public llvm::LegacyJITSymbolResolver { - public: - explicit SimpleResolver(ExternalConstantPool* external_constant_pool) - : external_constant_pool_(external_constant_pool) {} - - llvm::JITSymbol findSymbol(const std::string& name) override { - if (const uint8* from_constant_pool = - external_constant_pool_->Find(string(name))) { - return llvm::JITEvaluatedSymbol( - reinterpret_cast(from_constant_pool), - llvm::JITSymbolFlags::None); - } - - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); - if (func_addr == nullptr) { - return nullptr; - } - llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), - llvm::JITSymbolFlags::None); - return symbol_info; - } - llvm::JITSymbol findSymbolInLogicalDylib(const std::string& name) override { - return nullptr; - } - - private: - ExternalConstantPool* external_constant_pool_; -}; - llvm::SmallVector DetectMachineAttributes() { llvm::SmallVector result; llvm::StringMap host_features; @@ -101,27 +69,6 @@ llvm::StringRef GetHostCpuName() { cpu_name.consume_back("-avx512"); return cpu_name; } - -CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { - CompilerFunctor::VectorIntrinsics intrinsics; -#ifdef TF_XLA_HAS_SSE4_1 - intrinsics.sse_intrinsics = true; -#else - intrinsics.sse_intrinsics = false; -#endif -#ifdef TF_XLA_HAS_AVX - intrinsics.avx_intrinsics = true; -#else - intrinsics.avx_intrinsics = false; -#endif -#ifdef TF_XLA_HAS_NEON - intrinsics.neon_intrinsics = true; -#else - intrinsics.neon_intrinsics = false; -#endif - return intrinsics; -} - } // namespace SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, @@ -143,43 +90,47 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, execution_session_(string_pool_), symbol_resolver_(llvm::orc::createLegacyLookupResolver( [this](const std::string& name) -> llvm::JITSymbol { - if (const uint8* from_constant_pool = - external_constant_pool_.Find(string(name))) { - return llvm::JITEvaluatedSymbol( - reinterpret_cast(from_constant_pool), - llvm::JITSymbolFlags::None); - } - - void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); - if (func_addr == nullptr) { - return nullptr; - } - llvm::JITEvaluatedSymbol symbol_info( - reinterpret_cast(func_addr), - llvm::JITSymbolFlags::None); - return symbol_info; + return this->ResolveRuntimeSymbol(name); }, [](llvm::Error Err) { cantFail(std::move(Err), "lookupFlags failed"); })), - object_layer_( - execution_session_, - [](llvm::orc::VModuleKey) { - return std::make_shared( - orc_jit_memory_mapper::GetInstance()); - }, - [this](llvm::orc::VModuleKey K) { return symbol_resolver_; }), - compile_layer_( - object_layer_, - CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, - optimize_for_size, enable_fast_math, - disable_expensive_passes, GetAvailableIntrinsics(), - std::move(pre_optimization_hook), - std::move(post_optimization_hook))) { + object_layer_(execution_session_, + [this](llvm::orc::VModuleKey) { + llvm::orc::RTDyldObjectLinkingLayer::Resources result; + result.MemMgr = + std::make_shared( + orc_jit_memory_mapper::GetInstance()); + result.Resolver = symbol_resolver_; + return result; + }), + compile_layer_(object_layer_, + CompilerFunctor(target_machine_.get(), &disassembler_, + opt_level, optimize_for_size, + enable_fast_math, disable_expensive_passes, + std::move(pre_optimization_hook), + std::move(post_optimization_hook))) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() << " features: " << target_machine_->getTargetFeatureString().str(); } +llvm::JITSymbol SimpleOrcJIT::ResolveRuntimeSymbol(const std::string& name) { + if (const uint8* from_constant_pool = + external_constant_pool_.Find(string(name))) { + return llvm::JITEvaluatedSymbol( + reinterpret_cast(from_constant_pool), + llvm::JITSymbolFlags::None); + } + + void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + if (func_addr == nullptr) { + return nullptr; + } + llvm::JITEvaluatedSymbol symbol_info(reinterpret_cast(func_addr), + llvm::JITSymbolFlags::None); + return symbol_info; +} + SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule( std::unique_ptr module) { auto key = execution_session_.allocateVModule(); @@ -194,19 +145,13 @@ void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) { cantFail(compile_layer_.removeModule(key)); } -llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string& name) { - std::string mangled_name; - { - llvm::raw_string_ostream mangled_name_stream(mangled_name); - llvm::Mangler::getNameWithPrefix(mangled_name_stream, name, data_layout_); - } - +llvm::JITSymbol SimpleOrcJIT::FindCompiledSymbol(const std::string& name) { // Resolve symbol from last module to first, allowing later redefinitions of // symbols shadow earlier ones. for (auto& key : llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) { if (auto symbol = - compile_layer_.findSymbolIn(key, mangled_name, + compile_layer_.findSymbolIn(key, name, /*ExportedSymbolsOnly=*/true)) { return symbol; } @@ -233,26 +178,22 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF16); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); -#ifdef TF_XLA_HAS_NEON - REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON); -#endif -#ifdef TF_XLA_HAS_SSE4_1 - REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE); -#endif -#ifdef TF_XLA_HAS_AVX - REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX); -#endif REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); + registry->Register("__gnu_f2h_ieee", reinterpret_cast(__gnu_f2h_ieee)); + registry->Register("__gnu_h2f_ieee", reinterpret_cast(__gnu_h2f_ieee)); + #undef REGISTER_CPU_RUNTIME_SYMBOL // Register both the f32 (float) and f64 (double) versions of a libm symbol. diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index 50993afc8f73617a2c65310ae73b3ab00519f550..aaeff2de8785b99d271f13b261c63118bcf7bd4a 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -46,9 +46,7 @@ namespace cpu { class SimpleOrcJIT { public: using ObjLayerT = llvm::orc::RTDyldObjectLinkingLayer; - using CompileFtor = - std::function( - llvm::Module&)>; + using CompileFtor = std::function; using CompileLayerT = llvm::orc::IRCompileLayer; using VModuleKeyT = llvm::orc::VModuleKey; @@ -89,7 +87,7 @@ class SimpleOrcJIT { // Get the runtime address of the compiled symbol whose name is given. Returns // nullptr if the symbol cannot be found. - llvm::JITSymbol FindSymbol(const std::string& name); + llvm::JITSymbol FindCompiledSymbol(const std::string& name); llvm::TargetMachine* target_machine() const { return target_machine_.get(); } @@ -98,6 +96,8 @@ class SimpleOrcJIT { } private: + llvm::JITSymbol ResolveRuntimeSymbol(const std::string& name); + std::vector module_keys_; std::unique_ptr target_machine_; const Disassembler disassembler_; diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc index ec4215b4689bf822a6351522828dfb983c199305..cd1165e23812861ba9951546b7dd744529232196 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -103,15 +103,93 @@ llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) { } } -llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, double low, - double high) { +llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, + const llvm::APFloat& low, + const llvm::APFloat& high) { AssertCorrectTypes({a}); llvm::Type* type = a->getType(); - CHECK_LT(low, high); + CHECK(low.compare(high) == llvm::APFloat::cmpLessThan); CHECK(scalar_type_->isFloatingPointTy()); return llvm_ir::EmitFloatMin( - llvm_ir::EmitFloatMax(a, llvm::ConstantFP::get(type, low), ir_builder_), - llvm::ConstantFP::get(type, high), ir_builder_); + llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_), + GetConstantFloat(type, high), ir_builder_); +} + +llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name())); +} + +llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name())); +} + +llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name())); +} + +llvm::Value* VectorSupportLibrary::I1ToFloat(llvm::Value* i1) { + bool is_vector = llvm::isa(i1->getType()); + llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector); + return ir_builder()->CreateBitCast( + ir_builder()->CreateSExt(i1, integer_type, name()), + is_vector ? vector_type() : scalar_type(), name()); +} + +llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { + CHECK(scalar_type()->isFloatingPointTy()); + const llvm::DataLayout& data_layout = + ir_builder()->GetInsertBlock()->getModule()->getDataLayout(); + int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type()); + llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits); + if (vector) { + return llvm::VectorType::get(scalar_int_type, vector_size()); + } else { + return scalar_int_type; + } +} + +llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) { + CHECK_EQ(x->getType(), scalar_type()); + return ir_builder()->CreateVectorSplat(vector_size(), x, name()); +} + +llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + llvm::Type* int_type = + IntegerTypeForFloatSize(lhs->getType() == vector_type()); + return ir_builder()->CreateBitCast( + ir_builder()->CreateAnd( + ir_builder()->CreateBitCast(lhs, int_type, name()), + ir_builder()->CreateBitCast(rhs, int_type, name()), name()), + vector_type()); +} + +llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) { + AssertCorrectTypes({lhs}); + llvm::Type* int_type = + IntegerTypeForFloatSize(lhs->getType() == vector_type()); + return ir_builder()->CreateBitCast( + ir_builder()->CreateNot( + ir_builder()->CreateBitCast(lhs, int_type, name()), name()), + vector_type()); +} + +llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + llvm::Type* int_type = + IntegerTypeForFloatSize(lhs->getType() == vector_type()); + return ir_builder()->CreateBitCast( + ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()), + ir_builder()->CreateBitCast(rhs, int_type, name()), + name()), + vector_type(), name()); } llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, @@ -292,6 +370,9 @@ std::vector VectorSupportLibrary::ComputeHorizontalSums( std::vector VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( std::vector vectors, llvm::Value* init_values) { + // vectors are N llvm vector values, each with N elements. + int64 lane_width = vectors.size(); + while (vectors.size() != 2) { std::vector new_vectors; for (int i = 0; i < vectors.size(); i += 2) { @@ -312,10 +393,14 @@ VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( high = AddInternal(ExtractHighHalf(init_values), high); } + // `low` has the first `lane_width / 2` horizontal reductions, and `high` has + // the next `lane_width / 2` horizontal reductions. + std::vector results; - for (int i = 0; i < 8; i++) { + for (int i = 0; i < lane_width; i++) { llvm::Value* scalar_result = ir_builder()->CreateExtractElement( - i < 4 ? low : high, ir_builder()->getInt32(i % 4), name()); + i < (lane_width / 2) ? low : high, + ir_builder()->getInt32(i % (lane_width / 2)), name()); results.push_back(scalar_result); } diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h index 5c5d703db5b489130556e315838183fa49626649..6479bf76aab581ae3ec2923d98dab53720cab203 100644 --- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -26,6 +26,16 @@ limitations under the License. namespace xla { namespace cpu { + +// Simple wrappers around llvm::APFloat::APFloat to make the calling code more +// obvious. + +inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); } +inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) { + return llvm::APFloat(llvm::APFloat::IEEEsingle(), + llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value)); +} + // A thin wrapper around llvm_util.h to make code generating vector math flow // more readable. class VectorSupportLibrary { @@ -41,40 +51,94 @@ class VectorSupportLibrary { llvm::Value* Mul(int64 lhs, llvm::Value* rhs) { return Mul(ir_builder()->getInt64(lhs), rhs); } - llvm::Value* Mul(double lhs, llvm::Value* rhs) { - return Mul(llvm::ConstantFP::get(rhs->getType(), lhs), rhs); + llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) { + return Mul(GetConstantFloat(rhs->getType(), lhs), rhs); } + // If your call resolved to these then you probably wanted the versions taking + // APFloat. + llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete; + llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete; + llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* Add(int64 lhs, llvm::Value* rhs) { return Add(ir_builder()->getInt64(lhs), rhs); } - llvm::Value* Add(double lhs, llvm::Value* rhs) { - return Add(llvm::ConstantFP::get(vector_type(), lhs), rhs); + llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) { + return Add(GetConstantFloat(rhs->getType(), lhs), rhs); } + // If your call resolved to these then you probably wanted the versions taking + // APFloat. + llvm::Value* Add(double lhs, llvm::Value* rhs) = delete; + llvm::Value* Add(float lhs, llvm::Value* rhs) = delete; + llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) { + return Sub(lhs, GetConstantFloat(lhs->getType(), rhs)); + } llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) { + return Max(GetConstantFloat(rhs->getType(), lhs), rhs); + } llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs); llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) { return Add(c, Mul(a, b)); } - llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, double c) { - return Add(llvm::ConstantFP::get(vector_type(), c), Mul(a, b)); + llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) { + return Add(GetConstantFloat(vector_type(), c), Mul(a, b)); } - llvm::Value* MulAdd(llvm::Value* a, double b, double c) { - return Add(llvm::ConstantFP::get(a->getType(), c), - Mul(a, llvm::ConstantFP::get(a->getType(), b))); + llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b, + const llvm::APFloat& c) { + return Add(GetConstantFloat(a->getType(), c), + Mul(a, GetConstantFloat(a->getType(), b))); } llvm::Value* Floor(llvm::Value* a); - llvm::Value* Clamp(llvm::Value* a, double low, double high); - llvm::Value* SplatFloat(double d) { - return llvm::ConstantFP::get(vector_type(), d); + llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low, + const llvm::APFloat& high); + llvm::Value* SplatFloat(const llvm::APFloat& d) { + return GetConstantFloat(vector_type(), d); + } + + // These compare instructions return a floating point typed mask instead of an + // i1. For instance, on a vector typed input, lanes where the predicate is + // true get a float with all ones and other lanes get a float with all zeros. + // This is slightly odd from the perspective of LLVM's type system, but it + // makes kernel IR generation code written using VectorSupportLibrary (its + // raison d'etre) less cluttered. + + llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + + // These boolean operations operate on the bitwise values of the floating + // point inputs. They return a (vector of) float(s) but like in the mask + // generating predicates above this type system oddity makes the kernel IR + // generation code less cluttered. + llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + llvm::Value* FloatNot(llvm::Value* lhs); + llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) { + return FloatAnd(FloatNot(lhs), rhs); + } + + llvm::Value* BroadcastScalar(llvm::Value* x); + llvm::Value* BroadcastScalar(const llvm::APFloat& d) { + return BroadcastScalar(GetConstantFloat(scalar_type(), d)); } llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, @@ -194,6 +258,16 @@ class VectorSupportLibrary { std::vector ComputeAvxOptimizedHorizontalSums( std::vector vectors, llvm::Value* init_values); + llvm::Type* IntegerTypeForFloatSize(bool vector); + llvm::Value* I1ToFloat(llvm::Value* i1); + llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) { + llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f); + if (llvm::isa(type)) { + return llvm::ConstantVector::getSplat(vector_size(), scalar_value); + } + return scalar_value; + } + int64 vector_size_; PrimitiveType primitive_type_; llvm::IRBuilder<>* ir_builder_; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index a803b3171f9afa6297553c5507c4f9aa45e420ab..56723e765048698baedc50ae7b189d0287ee56b8 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -190,6 +190,7 @@ class DfsHloVisitorBase { virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; + virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; virtual Status HandleRng(HloInstructionPtr hlo) = 0; virtual Status HandleReverse(HloInstructionPtr hlo) = 0; virtual Status HandleSort(HloInstructionPtr hlo) = 0; @@ -213,6 +214,7 @@ class DfsHloVisitorBase { virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; virtual Status HandleWhile(HloInstructionPtr hlo) = 0; virtual Status HandleConditional(HloInstructionPtr hlo) = 0; + virtual Status HandleGather(HloInstructionPtr hlo) = 0; virtual Status HandlePad(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 170adb3d241b3648bc53f96dde9866f0b794f80a..ecda5288ee17a3856ce95f0caa327c3524fd180b 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -103,6 +103,9 @@ class DfsHloVisitorWithDefaultBase Status HandleOutfeed(HloInstructionPtr outfeed) override { return DefaultAction(outfeed); } + Status HandleHostCompute(HloInstructionPtr host_compute) override { + return DefaultAction(host_compute); + } Status HandleReverse(HloInstructionPtr reverse) override { return DefaultAction(reverse); } @@ -185,6 +188,9 @@ class DfsHloVisitorWithDefaultBase Status HandleSendDone(HloInstructionPtr send_done) override { return DefaultAction(send_done); } + Status HandleGather(HloInstructionPtr gather) override { + return DefaultAction(gather); + } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 4468adbadbf823f1420a8b665a26f66cb7d36b43..c732974995f70d9ba1b46e18aa4cc2c6ab467182 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -226,7 +226,7 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( if (primitive_util::IsIntegralType(to_type)) { return ir_builder_->CreateIntCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), - primitive_util::IsSignedIntegralType(to_type)); + primitive_util::IsSignedIntegralType(from_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (to_type == BF16) { @@ -1003,6 +1003,30 @@ StatusOr ElementalIrEmitter::EmitReducePrecision( ir_builder_); } +static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* ir_builder, + llvm::Value* lhs, llvm::Value* rhs, + llvm::Value* shift_result, + bool saturate_to_sign_bit) { + llvm::IntegerType* integer_type = + llvm::cast(lhs->getType()); + unsigned integer_bitsize = integer_type->getBitWidth(); + llvm::ConstantInt* integer_bitsize_constant = + llvm::ConstantInt::get(integer_type, integer_bitsize); + llvm::ConstantInt* zero = llvm::ConstantInt::get(integer_type, 0); + llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1); + llvm::Value* saturated_value; + if (saturate_to_sign_bit) { + saturated_value = ir_builder->CreateSelect( + ir_builder->CreateICmpSLT(lhs, zero), minus_one, zero); + } else { + saturated_value = zero; + } + llvm::Value* shift_amt_in_range = + ir_builder->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk"); + return ir_builder->CreateSelect(shift_amt_in_range, shift_result, + saturated_value); +} + StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { @@ -1050,12 +1074,27 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( return ir_builder_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: return ir_builder_->CreateOr(lhs_value, rhs_value); - case HloOpcode::kShiftLeft: - return ir_builder_->CreateShl(lhs_value, rhs_value); + + // Shifting out bits >= the number of bits in the type being shifted + // produces a poison value in LLVM which is basically "deferred undefined + // behavior" -- doing something observable with such a value precipitates + // UB. We replace the poison value with a constant to avoid this deferred + // UB. case HloOpcode::kShiftRightArithmetic: - return ir_builder_->CreateAShr(lhs_value, rhs_value); + return SaturateShiftIfNecessary( + ir_builder_, lhs_value, rhs_value, + ir_builder_->CreateAShr(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/true); + case HloOpcode::kShiftLeft: + return SaturateShiftIfNecessary( + ir_builder_, lhs_value, rhs_value, + ir_builder_->CreateShl(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/false); case HloOpcode::kShiftRightLogical: - return ir_builder_->CreateLShr(lhs_value, rhs_value); + return SaturateShiftIfNecessary( + ir_builder_, lhs_value, rhs_value, + ir_builder_->CreateLShr(lhs_value, rhs_value), + /*saturate_to_sign_bit=*/false); default: return Unimplemented("binary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index ed78fef4113bd9f7048ca3c8c2d4e38c5ec4762a..2029c303d47e9a62135b003c3bd9be6f8b3438d4 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -98,6 +98,14 @@ StatusOr> BufferAllocations::Builder::Build( } } + if (VLOG_IS_ON(2)) { + for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { + const auto& buf = buffer_allocations->buffers_[i]; + VLOG(2) << "Buffer " << i << " -> " << buf.opaque() << " (" << buf.size() + << "B)"; + } + } + return std::move(buffer_allocations); } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index f76f15929d12eed63d8964acd61fb3fea3945006..461747b699b542ae0c8735aea34cc9e57c1fb387 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -45,7 +45,7 @@ ConvolutionThunk::ConvolutionThunk( const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, - const HloInstruction* hlo) + bool tensor_ops_enabled, const HloInstruction* hlo) : Thunk(Kind::kConvolution, hlo), convolution_kind_(convolution_kind), input_buffer_(input_buffer), @@ -58,29 +58,30 @@ ConvolutionThunk::ConvolutionThunk( output_shape_(output_shape), window_(window), dim_nums_(dim_nums), - algorithm_(algorithm) {} + algorithm_(algorithm), + tensor_ops_enabled_(tensor_ops_enabled) {} Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { - se::DeviceMemory input_data( - buffer_allocations.GetDeviceAddress(input_buffer_)); - se::DeviceMemory filter_data( - buffer_allocations.GetDeviceAddress(filter_buffer_)); - se::DeviceMemory output_data( - buffer_allocations.GetDeviceAddress(output_buffer_)); + se::DeviceMemoryBase input_data = + buffer_allocations.GetDeviceAddress(input_buffer_); + se::DeviceMemoryBase filter_data = + buffer_allocations.GetDeviceAddress(filter_buffer_); + se::DeviceMemoryBase output_data = + buffer_allocations.GetDeviceAddress(output_buffer_); se::DeviceMemoryBase scratch = buffer_allocations.GetDeviceAddress(scratch_buffer_); se::dnn::AlgorithmConfig algorithm_config( - se::dnn::AlgorithmDesc(algorithm_, /*use_tensor_ops=*/false)); + se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); TF_RETURN_IF_ERROR(RunCudnnConvolution( convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, stream)); - // Figure out which of output/input/filter is the result produced by this op, - // and write the result tuple. + // Figure out which of output/input/filter is the result produced by + // this op, and write the result tuple. void* result_ptr = [&] { switch (convolution_kind_) { case CudnnConvKind::kForward: diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index ca9ef5277b3369dea3f698d1bcf0ad190d2c5217..900d9cb6243088b56a1825fb3ab8c06cf8d74726 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -59,7 +59,7 @@ class ConvolutionThunk : public Thunk { const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, - const HloInstruction* hlo); + bool tensor_ops_enabled, const HloInstruction* hlo); ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; @@ -99,6 +99,7 @@ class ConvolutionThunk : public Thunk { const Window window_; const ConvolutionDimensionNumbers dim_nums_; int64 algorithm_; + bool tensor_ops_enabled_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc index 621b2d510fa98af40b89badebef5e45902f23d4c..1792893ae401bf16d2dd9e861607e8f3821a505e 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -135,15 +135,6 @@ std::vector GetAlgorithms(CudnnConvKind kind, break; } - // Remove any algorithms with tensor math enabled. These have lower precision - // than regular algorithms, and we don't yet have a way to turn this on/off in - // XLA. - algorithms.erase(std::remove_if(algorithms.begin(), algorithms.end(), - [&](const AlgorithmDesc& a) { - return a.tensor_ops_enabled(); - }), - algorithms.end()); - return algorithms; } @@ -172,7 +163,7 @@ string NumBytesToString(int64 bytes) { // cache misses and doing extra work. Overall, caching doesn't seem worth the // trouble, but we may want to revisit this if we ever find a model where // caching would speed up compilation a lot. -optional> +optional> CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, @@ -222,6 +213,7 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums); se::dnn::ProfileResult best_result; int64 best_result_bytes_used = 0; + for (const AlgorithmDesc& alg : GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { ScratchAllocator scratch_allocator(device_ordinal, allocator); @@ -229,14 +221,12 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " << instr->ToString(); - bool launch_ok = - RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - se::DeviceMemory(input_buf.ValueOrDie()), - se::DeviceMemory(filter_buf.ValueOrDie()), - se::DeviceMemory(output_buf.ValueOrDie()), - &scratch_allocator, window, dnums, - AlgorithmConfig(alg), &stream, &profile_result) - .ok(); + bool launch_ok = RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + input_buf.ValueOrDie(), filter_buf.ValueOrDie(), + output_buf.ValueOrDie(), &scratch_allocator, window, + dnums, AlgorithmConfig(alg), &stream, &profile_result) + .ok(); if (launch_ok && profile_result.is_valid()) { int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); @@ -260,8 +250,9 @@ CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( << AlgorithmToString(best_result.algorithm()) << ", takes " << best_result.elapsed_time_in_ms() << "ms, and uses " << best_result_bytes_used << "B of scratch memory."; - return std::make_pair(best_result.algorithm().algo_id(), - best_result_bytes_used); + return std::make_tuple(best_result.algorithm().algo_id(), + best_result.algorithm().tensor_ops_enabled(), + best_result_bytes_used); } LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString() @@ -277,19 +268,19 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( const auto& lhs_shape = instr->operand(0)->shape(); const auto& rhs_shape = instr->operand(1)->shape(); const auto& conv_result_shape = instr->shape().tuple_shapes(0); - optional> alg_and_scratch_bytes; + optional> alg_scratch_and_tc; if (call_target == kCudnnConvForwardCallTarget) { - alg_and_scratch_bytes = PickBestAlgorithm( + alg_scratch_and_tc = PickBestAlgorithm( CudnnConvKind::kForward, /*input_shape=*/lhs_shape, /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, instr->window(), instr->convolution_dimension_numbers(), instr); } else if (call_target == kCudnnConvBackwardInputCallTarget) { - alg_and_scratch_bytes = PickBestAlgorithm( + alg_scratch_and_tc = PickBestAlgorithm( CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), instr->convolution_dimension_numbers(), instr); } else if (call_target == kCudnnConvBackwardFilterCallTarget) { - alg_and_scratch_bytes = PickBestAlgorithm( + alg_scratch_and_tc = PickBestAlgorithm( CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, instr->window(), instr->convolution_dimension_numbers(), instr); @@ -298,17 +289,20 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( << instr->ToString(); } - if (!alg_and_scratch_bytes.has_value()) { + if (!alg_scratch_and_tc.has_value()) { return false; } int64 algorithm; + bool tensor_ops_enabled; int64 scratch_bytes; - std::tie(algorithm, scratch_bytes) = *alg_and_scratch_bytes; + + std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc; VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " << NumBytesToString(scratch_bytes) - << " of scratch memory: " << instr->ToString(); + << " of scratch memory: " << instr->ToString() + << " tensor_ops_enabled: " << tensor_ops_enabled; // Replace instr with a new CustomCall which has the correct algorithm, and // whose output shape has the appropriate amount of scratch memory. @@ -318,10 +312,15 @@ StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( ShapeUtil::MakeShape(U8, {scratch_bytes})}); HloInstruction* algorithm_hlo = computation->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); + HloInstruction* tensor_ops_enabled_hlo = + computation->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(tensor_ops_enabled))); + HloInstruction* new_call = computation->AddInstruction(HloInstruction::CreateCustomCall( new_call_shape, - {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo}, + {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo, + tensor_ops_enabled_hlo}, instr->custom_call_target())); new_call->set_window(instr->window()); new_call->set_convolution_dimension_numbers( diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h index 10e49daee5df187e5ad90b7adf8c92aa9a63ba21..516210ec2e500cf03774d27408300ac3346e7b4f 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -47,7 +47,7 @@ class CudnnConvolutionAlgorithmPicker : public HloPassInterface { private: StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - tensorflow::gtl::optional> PickBestAlgorithm( + tensorflow::gtl::optional> PickBestAlgorithm( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc index f5f52cf62bf6edb7925ec3b22fc1772ffbfbf089..e4ae839e1dd4cb3a744a3f6a3329cabdaeb3f38d 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -70,42 +70,17 @@ class ScratchBufAllocator : public se::ScratchAllocator { bool allocated_ = false; }; -} // anonymous namespace - -string CudnnConvKindToString(CudnnConvKind kind) { - switch (kind) { - case CudnnConvKind::kForward: - return "forward"; - case CudnnConvKind::kBackwardFilter: - return "backward_filter"; - case CudnnConvKind::kBackwardInput: - return "backward_input"; - } -} - -Status RunCudnnConvolution(CudnnConvKind kind, const Shape& input_shape, - const Shape& filter_shape, const Shape& output_shape, - DeviceMemory input_buf, - DeviceMemory filter_buf, - DeviceMemory output_buf, - DeviceMemoryBase scratch_buf, const Window& window, - const ConvolutionDimensionNumbers& dnums, - AlgorithmConfig algorithm, Stream* stream, - ProfileResult* profile_result /*= nullptr*/) { - ScratchBufAllocator scratch_allocator(scratch_buf); - return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, - input_buf, filter_buf, output_buf, - &scratch_allocator, window, dnums, algorithm, - stream, profile_result); -} - +template Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, DeviceMemory input_buf, - DeviceMemory filter_buf, DeviceMemory output_buf, + const Shape& output_shape, DeviceMemory input_buf, + DeviceMemory filter_buf, DeviceMemory output_buf, se::ScratchAllocator* scratch_allocator, const Window& window, const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm, Stream* stream, ProfileResult* profile_result /*= nullptr*/) { + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); + VLOG(3) << "tensor_ops_enabled: " + << algorithm.algorithm().tensor_ops_enabled(); VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }"; VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }"; @@ -121,8 +96,16 @@ Status RunCudnnConvolution( // tensorflow/python/ops/nn_ops.py). const int effective_num_dimensions = std::max(2, num_dimensions); - CHECK_EQ(F32, output_shape.element_type()) - << ShapeUtil::HumanString(output_shape); + if (std::is_same::value) { + CHECK_EQ(F32, output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); + } else if (std::is_same::value) { + CHECK_EQ(F16, output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); + } else { + LOG(FATAL) << ShapeUtil::HumanString(output_shape); + } + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); @@ -217,5 +200,63 @@ Status RunCudnnConvolution( return Status::OK(); } +} // anonymous namespace + +string CudnnConvKindToString(CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kForward: + return "forward"; + case CudnnConvKind::kBackwardFilter: + return "backward_filter"; + case CudnnConvKind::kBackwardInput: + return "backward_input"; + } +} + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, + perftools::gputools::DeviceMemoryBase filter_buf, + perftools::gputools::DeviceMemoryBase output_buf, + perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, + const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result) { + ScratchBufAllocator scratch_allocator(scratch_buf); + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + input_buf, filter_buf, output_buf, + &scratch_allocator, window, dnums, algorithm, + stream, profile_result); +} + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, + perftools::gputools::DeviceMemoryBase filter_buf, + perftools::gputools::DeviceMemoryBase output_buf, + perftools::gputools::ScratchAllocator* scratch_allocator, + const Window& window, const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result) { + PrimitiveType output_primitive_type = output_shape.element_type(); + CHECK(output_primitive_type == F32 || output_primitive_type == F16) + << ShapeUtil::HumanString(output_shape); + if (output_primitive_type == F32) { + return RunCudnnConvolution( + kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf), se::DeviceMemory(filter_buf), + se::DeviceMemory(output_buf), scratch_allocator, window, dnums, + algorithm, stream, profile_result); + } + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf), + se::DeviceMemory(filter_buf), + se::DeviceMemory(output_buf), + scratch_allocator, window, dnums, algorithm, + stream, profile_result); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h index b101f76510c129fd22b246e5f0348848192ecbba..3dbfa2730da359d3c7937140508017c4a7b02d6c 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -55,7 +55,10 @@ string CudnnConvKindToString(CudnnConvKind kind); // Note that depending on the value of CudnnConvKind, the result of this call // may be written into input_buf, filter_buf, or output_buf! // -// At the moment we only support cudnn convolutions over floats. +// At the moment we only support cudnn convolutions over float and half, and +// convolution with half data type is implemented with cudnn PSEUDO_HALF +// configuration, that is, the input values are half and the internal +// computation type is float. // // We provide one overload which takes a scratch buffer, and another which takes // an allocator which is responsible for allocating the scratch space. In @@ -69,10 +72,9 @@ string CudnnConvKindToString(CudnnConvKind kind); // that size, if you like. Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, - perftools::gputools::DeviceMemory input_buf, - perftools::gputools::DeviceMemory filter_buf, - perftools::gputools::DeviceMemory output_buf, + const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, + perftools::gputools::DeviceMemoryBase filter_buf, + perftools::gputools::DeviceMemoryBase output_buf, perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, const ConvolutionDimensionNumbers& dnums, perftools::gputools::dnn::AlgorithmConfig algorithm, @@ -81,10 +83,9 @@ Status RunCudnnConvolution( Status RunCudnnConvolution( CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, - const Shape& output_shape, - perftools::gputools::DeviceMemory input_buf, - perftools::gputools::DeviceMemory filter_buf, - perftools::gputools::DeviceMemory output_buf, + const Shape& output_shape, perftools::gputools::DeviceMemoryBase input_buf, + perftools::gputools::DeviceMemoryBase filter_buf, + perftools::gputools::DeviceMemoryBase output_buf, perftools::gputools::ScratchAllocator* scratch_allocator, const Window& window, const ConvolutionDimensionNumbers& dnums, perftools::gputools::dnn::AlgorithmConfig algorithm, diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index 8e3aebbc12b5e6d746700956b9743bc94db50167..ba482793e7632f0f423cc9da0dd9620bdf29c642 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -137,9 +137,9 @@ StatusOr DoGemmAutotune( // for all algorithms if we're targeting < sm_50. But because we pass a // non-null ProfileResult, DoGemmWithAlgorithm should always return true, // and the actual success-ness is returned in ProfileResult::is_valid. - DCHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, - computation_type, algorithm, stream, - &profile_result)); + CHECK(DoGemmWithAlgorithm(lhs_matrix, rhs_matrix, output_matrix, + computation_type, algorithm, stream, + &profile_result)); if (profile_result.is_valid() && profile_result.elapsed_time_in_ms() < best_result.elapsed_time_in_ms()) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 88bf5a74fa03618d6f61365450f05e6f5d1a0c86..9db85bc788bde46c890a46ce9b0902ddce3f5675 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -49,7 +49,7 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { TF_ASSIGN_OR_RETURN(bool changed, generic_copy_insertion.Run(module)); TF_ASSIGN_OR_RETURN(std::unique_ptr dataflow, - HloDataflowAnalysis::Run(module)); + HloDataflowAnalysis::Run(*module)); // Make sure all operands of a library call are in memory instead of constants // in IR. @@ -79,9 +79,9 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } } else if (IsCustomCallToDnnConvolution(*hlo)) { - // The last argument to a CUDNN convolution is its algorithm, which must - // be an HLO constant -- it shouldn't be copied. - for (int64 i = 0; i < hlo->operand_count() - 1; ++i) { + // The last two arguments to a CUDNN convolution are two HLO constants for + // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied. + for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } } else if (ImplementedAsLibraryCall(*hlo)) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index f5d67b9ea9498df3f023ea9a694a63b468c5be18..04b37d913e0bc8f8226057f107da05fd1e675010 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -46,12 +46,14 @@ namespace { class HloExecutionProfiler { public: // If profiling is enabled, start an execution timer running. - explicit HloExecutionProfiler(bool do_profile, HloExecutionProfile* profile, - se::Stream* stream, - const HloComputation* computation) + explicit HloExecutionProfiler( + bool do_profile, HloExecutionProfile* profile, se::Stream* stream, + const std::vector::SmartPtr>& sub_streams, + const HloComputation* computation) : do_profile_(do_profile), profile_(profile), stream_(stream), + sub_streams_(sub_streams), computation_(computation) { if (do_profile_) { clock_rate_ghz_ = @@ -70,6 +72,7 @@ class HloExecutionProfiler { CHECK(!finished_execution_) << "Call FinishExecution only once!"; finished_execution_ = true; if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); stream_->ThenStopTimer(execution_timer_.get()); stream_->BlockHostUntilDone().IgnoreError(); profile_->set_total_cycles_executed( @@ -88,6 +91,7 @@ class HloExecutionProfiler { // that the hlo_instruction took to execute in the profile. void FinishOperation(const HloInstruction* hlo_instruction) { if (do_profile_) { + stream_->ThenWaitFor(&sub_streams_); stream_->ThenStopTimer(per_op_timer_.get()); stream_->BlockHostUntilDone().IgnoreError(); profile_->SetCyclesTakenBy( @@ -100,6 +104,7 @@ class HloExecutionProfiler { double clock_rate_ghz_; HloExecutionProfile* profile_; se::Stream* stream_; + const std::vector::SmartPtr>& sub_streams_; const HloComputation* computation_; std::unique_ptr execution_timer_; std::unique_ptr per_op_timer_; @@ -147,13 +152,9 @@ Status GpuExecutable::ExecuteThunks( LOG(WARNING) << "PROFILING: profiling is enabled"; } - HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, - hlo_module_->entry_computation()); - - uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // Stream 0 indicates `main_stream` and substreams start from stream 1. std::vector::SmartPtr> sub_streams; + sub_streams.reserve(thunk_schedule_->StreamCount() - 1); while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { sub_streams.emplace_back(); TF_ASSIGN_OR_RETURN( @@ -161,6 +162,10 @@ Status GpuExecutable::ExecuteThunks( run_options->BorrowStream(main_stream->parent()->device_ordinal())); } + HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, + sub_streams, hlo_module_->entry_computation()); + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + // The next event enqueued on stream N must not run until the thunk at // last_blocking_thunk_for_stream[N] completes. std::map last_blocking_thunk_for_stream; @@ -262,9 +267,16 @@ StatusOr> GpuExecutable::ExecuteOnStream( ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); if (allocation.is_entry_computation_parameter()) { - auto param_no = allocation.parameter_number(); - buffer_allocations_builder.RegisterBuffer( - i, arguments[param_no]->root_buffer()); + // The caller must give us a buffer for ShapeIndex {} of every parameter. + // It can optionally give us a buffer for other ShapeIndices, but we + // ignore them: Because we can't rely on these sub-buffers' addresses + // being available, our generated code can't use them. Instead, it must + // chase pointers starting at the tuple root. + if (allocation.param_shape_index().empty()) { + auto param_no = allocation.parameter_number(); + buffer_allocations_builder.RegisterBuffer( + i, arguments[param_no]->root_buffer()); + } } } se::StreamExecutor* executor = run_options->stream()->parent(); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 1fe7970e7d94ad4a4cad6aabcfc84a1356753443..3d34311b4368d17cb074aaf33c71fc865e96387e 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -66,13 +66,14 @@ class HloToIrBindings { } llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } + void SetTempBufferBase(llvm::Value* v) { temp_buffer_base_ = v; } // A helper method that returns the base pointer of the IrArray containing the // output of "inst".at the given ShapeIndex. llvm::Value* GetBasePointer(const HloInstruction& hlo, const ShapeIndex& shape_index = {}) const { auto it = base_ptrs_.find(&hlo); - CHECK(it != base_ptrs_.end()); + CHECK(it != base_ptrs_.end()) << hlo.ToString(); return it->second.element(shape_index); } @@ -113,7 +114,7 @@ class HloToIrBindings { std::unordered_map> base_ptrs_; // The address of the memory block that contains all temporary buffers. - llvm::Value* temp_buffer_base_; + llvm::Value* temp_buffer_base_ = nullptr; llvm_ir::AliasAnalysis alias_analysis_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 7ad9680bfb4a2ec0d43e2fe86fd138a4a46e2935..59455f389e733fee2d6cace7486f919a0c5e834e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -63,10 +63,11 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); // strings. // // These CustomCalls have window() and convolution_dimension_numbers() set like -// regular convolution ops. They have the same LHS and RHS operands, plus one -// additional int64 operand, representing which cudnn algorithm to run. This -// operand must be an HLO constant. A value of -1 means that the implementation -// is free to choose the best algorithm it can. +// regular convolution ops. They have the same LHS and RHS operands, plus two +// additional constant operands: an int64 operand for the cudnn algorithm and +// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn +// algorithm means that the implementation is free to choose the best algorithm +// it can. // // These calls output a tuple (conv_result, scratch_memory), where conv_result // is the actual result of the convolution, and scratch_memory is temporary diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index c81dfbf6c2a34aeb6d92ded23b8e264ebec30d54..30c88c0a5d38f6ea3f94d3b47b7b69c7122bf6ac 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -75,6 +75,10 @@ namespace gpu { namespace { using llvm_ir::IrName; +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; +using tensorflow::strings::StrCat; // If a dimensions is smaller than this, untiled transposition may be more // efficient. @@ -137,6 +141,38 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::MDString::get(llvm_context, "reqntidx"), llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } + +// Tries to get a Slice for the given instruction at the given index, but +// returns nullopt if we might not know the slice's address at runtime without +// dereferencing a containing tuple. +// +// In particular, when XLA accepts a parameter of tuple type, the caller has the +// option of telling XLA what are the values inside of the tuple, or just giving +// XLA a pointer to the top-level tuple and letting us chase the pointers on the +// GPU. We therefore cannot rely having these pointers to parameter sub-buffers +// being present when we run the program. +optional GetKnownAtRuntimeSlice( + const HloInstruction* instr, const ShapeIndex& index, + const BufferAssignment& buffer_assn) { + auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index); + if (!maybe_slice.ok()) { + return nullopt; + } + // BufferAllocation gives a slice and alloc to every buffer accessed by XLA, + // but we don't necessarily know the runtime address of sub-buffers of input + // parameters. + const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie(); + const BufferAllocation* alloc = slice.allocation(); + if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() && + !alloc->param_shape_index().empty()) { + return nullopt; + } + + // Otherwise, we will know the address of this slice at runtime without having + // to dereference a tuple. + return slice; +} + } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, @@ -154,16 +190,20 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { } namespace { -bool ImplementedAsHostToDeviceMemcpy(const HloInstruction& hlo) { - // `hlo` needs to satisfy three conditions to be implemented as a +bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment, + const HloInstruction& hlo) { + // `hlo` needs to satisfy the following conditions to be implemented as a // host-to-device cuMemcpy. // // 1. `hlo` is a kCopy instruction. // 2. `hlo`'s only operand is a kConstant instruction. // 3. `hlo` and its operand have the same shape (thus the same layout too). + // 4. The address of `hlo`'s buffer is known at runtime (without dereferencing + // pointers in a tuple). return hlo.opcode() == HloOpcode::kCopy && hlo.operand(0)->opcode() == HloOpcode::kConstant && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()); + ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && + GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value(); } bool ImplementedAsDeviceToDeviceMemcpy( @@ -177,13 +217,15 @@ bool ImplementedAsDeviceToDeviceMemcpy( // instance) which means the source buffer also resides on the device. return hlo.opcode() == HloOpcode::kCopy && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.HasTopLevelAllocation(hlo.operand(0)); + GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() && + GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment) + .has_value(); } } // namespace llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice escaped_hlos) { + tensorflow::gtl::ArraySlice args) { // Compute the kernel name. The opcode string may contain "-" which cannot be // in a PTX function name, so sanitize the name before uniquifying it. string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( @@ -192,43 +234,32 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( // Create the kernel and add it to the module. llvm::Module* module = ir_emitter_context_->llvm_module(); llvm::LLVMContext& context = module->getContext(); - int num_escaped_hlos = escaped_hlos.size(); llvm::FunctionType* kernel_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(context), - std::vector(num_escaped_hlos + 1, - ir_builder_.getInt8PtrTy()), + std::vector(args.size(), ir_builder_.getInt8PtrTy()), /*isVarArg=*/false); llvm::Function* kernel = llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, kernel_name.c_str(), module); - // Add dereferenceable information to each of the escaped HLO parameters. - for (size_t arg_no = 0; arg_no < escaped_hlos.size(); ++arg_no) { - const HloInstruction* escaped_hlo = escaped_hlos[arg_no]; - const Shape& escaped_hlo_shape = escaped_hlo->shape(); - int64 escaped_hlo_size = llvm_ir::ByteSizeOf( - escaped_hlo_shape, ir_emitter_context_->llvm_module()->getDataLayout()); - kernel->addDereferenceableAttr(arg_no + 1, escaped_hlo_size); - } - - // The last argument is a pointer to the temporary buffer memory block. - // We know that it doesn't alias any of the escaped arguments (the inputs + - // the result). We also know how many bytes can be dereferenced in it. - const llvm::Argument& temp_buffer = *std::prev(kernel->arg_end()); - int64 temp_buffer_arg_no = temp_buffer.getArgNo(); - int64 temp_allocation_total_size = - ir_emitter_context_->buffer_assignment().temp_allocation_total_size(); - if (temp_allocation_total_size != 0) { - kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, - temp_allocation_total_size); - } - kernel->addParamAttr(temp_buffer_arg_no, llvm::Attribute::NoAlias); + // Add dereferenceable and alignment information to each of the kernel's + // parameters. + auto arg_it = kernel->arg_begin(); + for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) { + const BufferAllocation* alloc = args[arg_no]; + llvm::Argument* fn_arg = &*arg_it; + ++arg_it; - // All arguments to a kernel must be aligned to kCudaMallocAlignBytes. - for (int64 i = 0; i < kernel->arg_size(); ++i) { + kernel->addDereferenceableAttr(arg_no + 1, alloc->size()); kernel->addParamAttr( - i, llvm::Attribute::get(context, llvm::Attribute::Alignment, - kCudaMallocAlignBytes)); + arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment, + kCudaMallocAlignBytes)); + + if (alloc->IsPreallocatedTempBuffer()) { + fn_arg->setName("temp_buf"); + } else { + fn_arg->setName(llvm_ir::AsStringRef(StrCat("alloc", alloc->index()))); + } } // TODO(b/65380986): Investigate if adding fast math flags for generated @@ -245,10 +276,9 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( // Update the insert point to the entry basic block. llvm::BasicBlock* entry_bb = - llvm::BasicBlock::Create(context, - "entry", // The name of the basic block. - kernel); // The parent/owner of "entry_bb". - // Emit a "return void" at entry_bb's end, and sets the insert point before + llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel); + + // Emit a "return void" at entry_bb's end, and set the insert point before // that return instruction. ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb)); @@ -393,6 +423,11 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); int64 algorithm = algorithm_inst->literal().Get({}); + const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3); + CHECK(tensor_ops_enabled_inst->IsConstant()) + << tensor_ops_enabled_inst->ToString(); + bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get({}); + const auto& target = custom_call->custom_call_target(); std::unique_ptr thunk; if (target == kCudnnConvForwardCallTarget) { @@ -407,7 +442,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, custom_call); + algorithm, tensor_ops_enabled, custom_call); } else if (target == kCudnnConvBackwardInputCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardInput, @@ -420,7 +455,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, custom_call); + algorithm, tensor_ops_enabled, custom_call); } else if (target == kCudnnConvBackwardFilterCallTarget) { thunk = MakeUnique( CudnnConvKind::kBackwardFilter, @@ -433,7 +468,7 @@ Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, // custom_call->window(), custom_call->convolution_dimension_numbers(), - algorithm, custom_call); + algorithm, tensor_ops_enabled, custom_call); } else { LOG(FATAL) << "Unexpected custom call target: " << custom_call->custom_call_target(); @@ -864,7 +899,8 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, } // namespace Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - if (ImplementedAsHostToDeviceMemcpy(*copy)) { + if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(), + *copy)) { thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy)); return Status::OK(); } @@ -1926,62 +1962,207 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { return Status::OK(); } -llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands( - const HloInstruction& hlo, std::vector* io_hlos) { - const BufferAssignment& buffer_assignment = - ir_emitter_context_->buffer_assignment(); - // GetTupleElement instructions are implemented by emitting IR that indexes - // and loads the target tuple element pointer from its operand (possibly - // recursively). For this reason, GetTupleElement instructions are associated - // with their operand buffer in 'io_hlos' and 'non_io_hlos' below. - std::vector non_io_hlos; - for (const HloInstruction* operand : hlo.operands()) { - const HloInstruction* to_lookup = operand->LatestNonGteAncestor(); - if (buffer_assignment.HasTopLevelAllocation(to_lookup) && - buffer_assignment.GetUniqueTopLevelSlice(to_lookup) - .ConsumeValueOrDie() - .allocation() - ->IsInputOrOutput()) { - io_hlos->push_back(operand); - } else { - non_io_hlos.push_back(operand); +// Figures out how to access the buffers for all subshapes of hlo's operands and +// for hlo itself (i.e. all the buffers produced by HLO). +// +// Returns a map keyed on the pair {HloInstruction, ShapeIndex}. The value for +// this key is a pair {Slice, ShapeIndex}, where the slice tells you the root +// buffer to look in, and the ShapeIndex describes how to dereference starting +// at that buffer to get to the buffer in question. +// +// For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for +// hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) +// is found at slice[3][4]. That is, slice is a void***, which we dereference +// twice -- first at index 3, and then at index 4 -- to get the address of our +// buffer. +// +// This function conservatively assumes that we'll touch all sub-buffers of +// every operand and of the output. +static std::map, + std::pair> +GetHloBufferSlices(const HloInstruction* hlo, + const BufferAssignment& buffer_assn) { + std::map, + std::pair> + slices; + + // Tries to find a slice plus an array of indices i1, ..., iN such that the + // sub-buffer for instr at index can be found at slice[i1]...[iN]. + auto find_slice_for = [&](const HloInstruction* instr, + const ShapeIndex& index) + -> optional> { + // Simple, common case: Is the buffer for instr known at runtime? If so, + // we're done. + auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn); + if (slice.has_value()) { + return {{*slice, ShapeIndex()}}; } - } - CHECK_NE(HloOpcode::kGetTupleElement, hlo.opcode()); - if (buffer_assignment.HasTopLevelAllocation(&hlo) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo) - .ConsumeValueOrDie() - .allocation() - ->IsInputOrOutput()) { - io_hlos->push_back(&hlo); - } else { - non_io_hlos.push_back(&hlo); + // If we don't know the buffer for instr at index, see if we know the buffer + // for instr at index without its last element. If so, we can dynamically + // find the buffer for instr by dereferencing a pointer in that buffer. + // Continue looking this way until we run out of elements in 'index'. + ShapeIndex new_index = index; + ShapeIndex gte_indices; + while (!new_index.empty()) { + gte_indices.push_front(new_index.back()); + new_index.pop_back(); + auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn); + if (slice.has_value()) { + return {{*slice, gte_indices}}; + } + } + + // If *that* didn't work, check whether instr is a GTE instruction. If it + // is, see if we can get a buffer for its parent, and continue walking up + // parents until we find a defined buffer or we hit something that's not a + // GTE. + const HloInstruction* parent = instr; + while (parent->opcode() == HloOpcode::kGetTupleElement) { + gte_indices.push_front(parent->tuple_index()); + parent = parent->operand(0); + + auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn); + if (slice.has_value()) { + return {{*slice, gte_indices}}; + } + } + + return nullopt; + }; + + // Adds entries for all subshapes of instr to `slices`. + auto add_slices_for = [&](const HloInstruction* instr) { + // GPU constants don't have buffers; don't bother looking for one. + if (instr->IsConstant()) { + return; + } + + ShapeUtil::ForEachSubshape( + instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) { + if (slices.count({instr, index})) { + // HLOs can have duplicate operands; don't bother redoing work. + return; + } + auto maybe_slice = find_slice_for(instr, index); + if (maybe_slice.has_value()) { + slices[{instr, index}] = *maybe_slice; + } else { + VLOG(1) << "Couldn't find buffer for " << instr->ToString() + << " at index " << index.ToString(); + } + }); + }; + + add_slices_for(hlo); + for (const HloInstruction* operand : hlo->operands()) { + // Conservatively assume we'll need the buffers for all subshapes of the + // operand. + add_slices_for(operand); } - llvm::Function* kernel = BuildKernelPrototype(hlo, *io_hlos); - // bindings_ is reused because the bindings of kConstant to their underlying - // llvm::Constant can be shared for all HLOs in this computation. - bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos); - return kernel; + return slices; +} + +Status IrEmitterUnnested::HandleGather(HloInstruction* gather) { + // TODO(b/72710576): Gather is not implemented on GPUs + return Unimplemented("Gather is not implemented on GPUs."); } std::unique_ptr IrEmitterUnnested::BuildKernelThunk( const HloInstruction* inst) { - std::vector io_hlos; - llvm::Function* kernel = - EmitBasePointersForHloAndItsOperands(*inst, &io_hlos); + const BufferAssignment& buffer_assn = + ir_emitter_context_->buffer_assignment(); + + std::map, + std::pair> + hlo_slices = GetHloBufferSlices(inst, buffer_assn); + + // Figure out which buffer allocations need to be passed as arguments to our + // kernel. This is simply all of the allocations referenced in hlo_slices, + // plus the XLA temp buffer (if we have it). We always include the temp + // buffer because even if the kernel itself doesn't use it, a nested + // subcomputation within the kernel (e.g. a kMap's computation) might. + std::unordered_set buffers_needed; + for (const auto& kv : hlo_slices) { + buffers_needed.insert(kv.second.first.allocation()); + } + tensorflow::gtl::optional temp_buffer; + for (const BufferAllocation& alloc : buffer_assn.Allocations()) { + if (alloc.IsPreallocatedTempBuffer()) { + if (!temp_buffer.has_value()) { + temp_buffer = &alloc; + } else { + LOG(FATAL) << "Multiple temp buffers found, but only one is allowed!"; + } + } + } + if (temp_buffer.has_value()) { + buffers_needed.insert(*temp_buffer); + } + + // We'll pass a pointer to each of the elements of `buffers` to our kernel, in + // this order. + std::vector buffers(buffers_needed.begin(), + buffers_needed.end()); + std::sort(buffers.begin(), buffers.end(), + [](const BufferAllocation* a, const BufferAllocation* b) { + return a->index() < b->index(); + }); + + llvm::Function* kernel = BuildKernelPrototype(*inst, buffers); - // Compute the input buffer indices. - std::vector io_buffers; - io_buffers.reserve(io_hlos.size()); - for (const HloInstruction* io_hlo : io_hlos) { - io_buffers.push_back(GetAllocationSlice(*io_hlo->LatestNonGteAncestor())); + // Build a map from a BufferAllocation to the corresponding argument in our + // kernel. + std::unordered_map kernel_args; + { + auto arg_it = kernel->arg_begin(); + auto buffers_it = buffers.begin(); + for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) { + kernel_args[*buffers_it] = arg_it; + } + } + + // For each buffer our kernel might want to touch, bind it to a value derived + // from our kernel args. + for (const auto& kv : hlo_slices) { + const HloInstruction* instr = kv.first.first; + const ShapeIndex& index = kv.first.second; + const BufferAllocation::Slice& slice = kv.second.first; + const ShapeIndex& gte_index = kv.second.second; + + VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString() + << " is found in slice " << slice.ToString() << " at GTE index " + << gte_index.ToString(); + + llvm::Value* loc = + ir_builder_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), + {ir_builder_.getInt64(slice.offset())}); + + // If gte_index is nonempty, we have to dereference `loc` to get to the + // value we're ultimately interested in. + llvm::Type* int8_double_pointer = + llvm::PointerType::get(ir_builder_.getInt8PtrTy(), /*AddressSpace=*/0); + for (int64 idx : gte_index) { + loc = ir_builder_.CreateBitCast(loc, int8_double_pointer); + loc = ir_builder_.CreateLoad( + ir_builder_.CreateInBoundsGEP(loc, {ir_builder_.getInt64(idx)})); + } + + bindings_.BindHloToIrValue(*instr, loc, index); + } + + // Bind the temp buffer so that nested subcomputations can find it if they + // need. + if (temp_buffer.has_value()) { + bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer)); + } else { + bindings_.SetTempBufferBase( + llvm::ConstantPointerNull::get(ir_builder_.getInt8PtrTy())); } - // Create a KernelThunk that launches the kernel that implements "inst". - return MakeUnique(io_buffers, - llvm_ir::AsString(kernel->getName()), inst); + return MakeUnique(buffers, llvm_ir::AsString(kernel->getName()), + inst); } std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index 56ab8208cee6f53afce365baa213fd2f5a6425a0..b83a2337e2decd9d4fba3d40fcf33f131fca8a3c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -67,6 +67,7 @@ class IrEmitterUnnested : public IrEmitter { Status HandleDot(HloInstruction* dot) override; Status HandleFft(HloInstruction* fft) override; Status HandleFusion(HloInstruction* fusion) override; + Status HandleGather(HloInstruction* gather) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleReduce(HloInstruction* reduce) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; @@ -93,14 +94,10 @@ class IrEmitterUnnested : public IrEmitter { std::unique_ptr BuildThunk(const HloInstruction* hlo); // Builds the prototype of the IR kernel for `inst` and adds it to the module. + // This kernel takes as arguments pointers to the given buffer allocations. llvm::Function* BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice escaped_hlos); - - // Emits the base pointers for `hlo` and its operands. `io_hlos` will store - // all input/output HLOs among `hlo` and its operands. - llvm::Function* EmitBasePointersForHloAndItsOperands( - const HloInstruction& hlo, std::vector* io_hlos); + tensorflow::gtl::ArraySlice args); // EmitColumnReduction and EmitRowReduction emit code for column and row // reduction of a matrix and/or 3D tensor. Row and column reduction have diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 96606993696354f36e143b3b994bbe6afb902df3..c20a781a33fe89af4740ed31dd5bfb1a64473057 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -29,10 +29,10 @@ namespace xla { namespace gpu { KernelThunk::KernelThunk( - tensorflow::gtl::ArraySlice io_buffers, + tensorflow::gtl::ArraySlice args, const string& kernel_name, const HloInstruction* hlo_instruction) : Thunk(Kind::kKernel, hlo_instruction), - io_buffers_(io_buffers.begin(), io_buffers.end()), + args_(args.begin(), args.end()), kernel_name_(kernel_name) {} tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { @@ -42,7 +42,7 @@ tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { return tensorflow::Status::OK(); } - loader_spec_.reset(new se::MultiKernelLoaderSpec(io_buffers_.size() + 1)); + loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); tensorflow::StringPiece ptx = executable.ptx(); // Convert tensorflow::StringPiece to se::port::StringPiece because // StreamExecutor uses the latter. @@ -81,15 +81,16 @@ tensorflow::Status KernelThunk::ExecuteOnStream( kernel = &it->second; } + VLOG(3) << "Launching " << kernel->name(); // Launch the kernel with potentially multiple blocks and threads. static constexpr int kKernelArgsLimit = 1024; auto kernel_args = MakeUnique>(); - for (const BufferAllocation::Slice io_buffer : io_buffers_) { - kernel_args->add_device_memory_argument( - buffer_allocations.GetDeviceAddress(io_buffer)); + for (const BufferAllocation* arg : args_) { + const auto& buf = buffer_allocations.GetDeviceAddress(arg->index()); + kernel_args->add_device_memory_argument(buf); + VLOG(3) << " Arg: alloc #" << arg->index() << ": " << buf.opaque() << " (" + << buf.size() << "B)"; } - kernel_args->add_device_memory_argument( - buffer_allocations.GetTempBufferBase()); if (!stream->parent()->Launch( stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 350b5aaf360b0dad7f7b04d73f4c32bad55d3ce9..9ae455e2fcc253a7a08ff95764721048a16b0bf7 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -46,7 +46,7 @@ class KernelThunk : public Thunk { // Constructs a thunk for the given kernel. // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. - KernelThunk(tensorflow::gtl::ArraySlice io_buffers, + KernelThunk(tensorflow::gtl::ArraySlice args, const string& kernel_name, const HloInstruction* hlo_instruction); KernelThunk(const KernelThunk&) = delete; KernelThunk& operator=(const KernelThunk&) = delete; @@ -63,8 +63,8 @@ class KernelThunk : public Thunk { perftools::gputools::Stream* stream) override; private: - // The indices of the input/output buffers. - const std::vector io_buffers_; + // Buffers passed to the kernel as arguments. + const std::vector args_; // Entry kernel name for the computation. const string kernel_name_; diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index cfabae791d26d0eb49826085ad7ad166a19109a1..defd281d74bd38f7da3f268e0f55970fc1af8263 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -252,7 +252,7 @@ void EmitBitcodeToFile(const Module& module, tensorflow::StringPiece filename) { LOG(FATAL) << "opening bitcode file for writing: " << error_code.message(); } - llvm::WriteBitcodeToFile(&module, outfile.os()); + llvm::WriteBitcodeToFile(module, outfile.os()); outfile.keep(); } diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index cde5877e29f36abc61c5417ce960e2c7699e2749..3dd4c4a0794e5c41b877078c4e69c6c9584ce6c0 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -27,38 +27,6 @@ namespace xla { using tensorflow::gtl::FlatMap; using tensorflow::gtl::FlatSet; -namespace { - -// Returns the set of buffers that may be sources of all operands of the given -// instruction. The returned buffers are guaranteed to have no duplicates, and -// to be sorted in a deterministic order. -std::vector UniqueOperandSourceBuffers( - const HloInstruction* instruction, - const TuplePointsToAnalysis& points_to_analysis) { - std::vector buffers; - for (const HloInstruction* operand : instruction->operands()) { - points_to_analysis.GetPointsToSet(operand).ForEachElement( - [&](const ShapeIndex& /*index*/, - const PointsToSet::BufferList& points_to) { - buffers.insert(buffers.end(), points_to.begin(), points_to.end()); - }); - } - - // Sort and then remove duplicates from buffers. - std::sort(buffers.begin(), buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() < b->id(); - }); - buffers.erase(std::unique(buffers.begin(), buffers.end(), - [](const LogicalBuffer* a, const LogicalBuffer* b) { - return a->id() == b->id(); - }), - buffers.end()); - return buffers; -} - -} // namespace - /*static*/ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, @@ -93,6 +61,7 @@ Status HeapSimulator::RunComputation( const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis) { + VLOG(3) << "Computation:\n" << computation.ToString(); // The goal here is to minimize memory usage, assuming the given sequential // ordering of instructions. The strategy is to walk through the instruction // sequence, calling Alloc and Free on the underlying heap algorithm. The @@ -101,7 +70,51 @@ Status HeapSimulator::RunComputation( // 'live_buffers' tracks the liveness of each buffer that we assign, by // associating it with a set of HloInstructions that need to be visited. When // the set becomes empty, the buffer is no longer used, and can be freed. + // 'used_buffers' is the reverse map - it tracks which buffers were used by an + // instruction, so that we can remove the instructions from a buffer's live + // set after they are visited. FlatMap> live_buffers; + FlatMap> used_buffers; + auto add_user_to_buffer = [this, &live_buffers, &used_buffers]( + const HloInstruction* user, + const LogicalBuffer* buffer) { + if (!IgnoreBuffer(buffer)) { + VLOG(4) << " Adding user " << user->name() << " to buffer " + << buffer->ToString(); + live_buffers[buffer].insert(user); + used_buffers[user].insert(buffer); + } + }; + + // Initialize live_buffers for each buffer that we're going to assign. The + // set of instructions that need to be visited contains all users of all + // aliases, that is, all users of all instructions that have the buffer + // contained in their points-to set. + for (const HloInstruction* instruction : instruction_sequence) { + const PointsToSet& points_to = + points_to_analysis.GetPointsToSet(instruction); + const PointsToSet::BufferSet& buffer_set = points_to.CreateFlattenedSet(); + for (const HloInstruction* user : instruction->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + for (const LogicalBuffer* buffer : buffer_set) { + add_user_to_buffer(user, buffer); + } + } else { + // A GetTupleElement doesn't need to keep all of its operand's buffers + // alive. It only needs the buffers that relate to the element its + // extracting, and the tuple it's extracting from, but not the buffers + // for the other elements. + for (const LogicalBuffer* buffer : points_to.element({})) { + add_user_to_buffer(user, buffer); + } + const PointsToSet& gte_points_to = + points_to_analysis.GetPointsToSet(user); + for (const LogicalBuffer* buffer : gte_points_to.CreateFlattenedSet()) { + add_user_to_buffer(user, buffer); + } + } + } + } const HloInstruction* root = computation.root_instruction(); auto output_source_buffers = @@ -114,34 +127,17 @@ Status HeapSimulator::RunComputation( buffers_defined_by_instruction = points_to_analysis.GetBuffersDefinedByInstruction(instruction); - // Initialize live_buffers for each buffer that we're going to assign. The - // set of instructions that need to be visited contains all users of all - // aliases. The alias itself is not necessary; if it has users, the users - // are necessarily scheduled after the alias. And if it has no users, it is - // either a dead value or an output, both of which are handled below. - // - // We ignore control dependencies here. The reasoning is that the control - // dependencies have already been accounted for in the ordering of the given - // 'instruction_sequence', and should not otherwise artificially extend the - // lifetime of buffers that aren't already connected by a data dependency. + VLOG(3) << "Instruction: " << instruction->ToString(); + for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { + VLOG(4) << " Defines: " << buffer->ToString() + << (IgnoreBuffer(buffer) ? " (Ignored)" : ""); + } + dead_buffers_to_free.clear(); for (const LogicalBuffer* buffer : buffers_defined_by_instruction) { if (IgnoreBuffer(buffer)) { continue; } - FlatSet* live_set = nullptr; - for (const BufferAlias& alias : - points_to_analysis.GetBufferAliases(*buffer)) { - const std::vector& users = - alias.instruction()->users(); - if (!users.empty()) { - if (live_set == nullptr) { - live_set = &live_buffers[buffer]; - } - live_set->insert(users.begin(), users.end()); - } - } - // Add a nullptr sentry to ensure entry parameters and output source // buffers are not freed until the very end. const bool entry_parameter = @@ -165,11 +161,12 @@ Status HeapSimulator::RunComputation( // have no instructions left to visit are moved from live_buffers to // operand_buffers_to_free. operand_buffers_to_free.clear(); - for (const LogicalBuffer* operand_buffer : - UniqueOperandSourceBuffers(instruction, points_to_analysis)) { + for (const LogicalBuffer* operand_buffer : used_buffers[instruction]) { if (IgnoreBuffer(operand_buffer)) { continue; } + VLOG(4) << " Removing user " << instruction->name() << " from buffer " + << operand_buffer->ToString(); auto it = live_buffers.find(operand_buffer); FlatSet* live_set = &it->second; live_set->erase(instruction); @@ -178,6 +175,11 @@ Status HeapSimulator::RunComputation( operand_buffers_to_free.push_back(operand_buffer); } } + // Sort to get a deterministic iteration order. + std::sort(operand_buffers_to_free.begin(), operand_buffers_to_free.end(), + [](const LogicalBuffer* x, const LogicalBuffer* y) { + return x->id() < y->id(); + }); // Allocate buffers defined by this instruction. This is the latest point // that we can allocate; right before the buffer is first used. This must @@ -203,6 +205,8 @@ Status HeapSimulator::RunComputation( CanShareOperandBufferWithUser( operand_buffer->instruction(), operand_buffer->index(), buffer->instruction(), buffer->index(), points_to_analysis)) { + VLOG(3) << " Sharing: " << buffer->ToString() << " with " + << operand_buffer->ToString(); ShareBuffer(buffer, operand_buffer, instruction); shared = true; break; @@ -211,6 +215,7 @@ Status HeapSimulator::RunComputation( } if (!shared) { + VLOG(3) << " Allocating: " << buffer->ToString(); Alloc(buffer, instruction); } } @@ -225,6 +230,7 @@ Status HeapSimulator::RunComputation( // sub-computations will never be run concurrently. if (module_sequence_ != nullptr) { if (instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kConditional || instruction->opcode() == HloOpcode::kWhile) { for (const HloComputation* called_computation : instruction->called_computations()) { @@ -243,20 +249,34 @@ Status HeapSimulator::RunComputation( // Free buffers that are no longer live. This is the earliest point that we // can de-allocate; right after the last use of the buffer. for (const LogicalBuffer* buffer : dead_buffers_to_free) { + VLOG(3) << " Freeing dead: " << buffer->ToString(); Free(buffer, instruction); } for (const LogicalBuffer* buffer : operand_buffers_to_free) { + VLOG(3) << " Freeing operand: " << buffer->ToString(); Free(buffer, instruction); } } // Any remaining live buffers must be entry parameters or output source - // buffers, which had a nullptr sentry added. Free them now. + // buffers, which had a nullptr sentry added. Free them now, in a + // deterministic order. + std::vector to_free; + to_free.reserve(live_buffers.size()); for (const auto& buffer_pending : live_buffers) { const LogicalBuffer* buffer = buffer_pending.first; const FlatSet& pending = buffer_pending.second; CHECK_EQ(pending.size(), 1) << *buffer; CHECK(*pending.begin() == nullptr) << *buffer; + to_free.push_back(buffer); + } + + std::sort(to_free.begin(), to_free.end(), + [](const LogicalBuffer* x, const LogicalBuffer* y) { + return x->id() < y->id(); + }); + for (const LogicalBuffer* buffer : to_free) { + VLOG(3) << "Freeing pending: " << buffer->ToString(); Free(buffer, root); } diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 387b649a731ebcbfd8307807469f39f22d192b06..688a271712ac243666ba4ff02932aa4f7f7ed21c 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -410,6 +410,56 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { }); } +TEST_F(HeapSimulatorTest, IndependentTupleElements) { + auto builder = HloComputation::Builder(TestName()); + auto paramA = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32scalar_, "paramA")); + auto paramB = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32scalar_, "paramB")); + auto mul = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kMultiply, paramA, paramB)); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kAdd, paramA, paramB)); + auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({mul, add})); + auto element0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 0)); + auto broadcast = builder.AddInstruction( + HloInstruction::CreateBroadcast(f32vec4_, element0, {0})); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + f32scalar_, HloOpcode::kSubtract, paramA, paramB)); + auto element1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32scalar_, tuple, 1)); + auto output = builder.AddInstruction( + HloInstruction::CreateTuple({broadcast, sub, element1})); + + HeapSimulatorTracker tracker(TestName(), builder.Build(), + {paramA, paramB, mul, add, tuple, element0, + broadcast, sub, element1, output}); + tracker.ExpectCallSequence({ + {kAlloc, tracker.BufferAt(paramA, {})}, + {kAlloc, tracker.BufferAt(paramB, {})}, + {kAlloc, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(add, {})}, + {kAlloc, tracker.BufferAt(tuple, {})}, + {kAlloc, tracker.BufferAt(broadcast, {})}, + // The mul can be freed right after the broadcast happens, even though + // The other GetTupleElement is still alive. + {kFree, tracker.BufferAt(mul, {})}, + {kAlloc, tracker.BufferAt(sub, {})}, + // The temporary tuple is now dead. + {kFree, tracker.BufferAt(tuple, {})}, + {kAlloc, tracker.BufferAt(output, {})}, + // All params and outputs are freed at the end. + {kFree, tracker.BufferAt(paramA, {})}, + {kFree, tracker.BufferAt(paramB, {})}, + {kFree, tracker.BufferAt(add, {})}, + {kFree, tracker.BufferAt(broadcast, {})}, + {kFree, tracker.BufferAt(sub, {})}, + {kFree, tracker.BufferAt(output, {})}, + {kFinish, nullptr}, + }); +} + TEST_F(HeapSimulatorTest, WholeModule) { HeapSimulatorTracker tracker(TestName()); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 0e9a852788e978f79fa6f6c802f855a4c476583f..a43785b4a9701369ae315f67d4d64d03dc6c081d 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -129,6 +129,10 @@ message HloInstructionProto { // FFT length. repeated int64 fft_length = 32; + + // Gather dimension numbers. + xla.GatherDimensionNumbers gather_dimension_numbers = 33; + repeated int64 gather_window_bounds = 34; } // Serialization of HloComputation. @@ -200,6 +204,7 @@ message BufferAllocationProto { bool is_reusable = 4; bool is_entry_computation_parameter = 5; int64 parameter_number = 6; + repeated int64 parameter_shape_index = 10; bool maybe_live_out = 7; int64 color = 8; repeated Assigned assigned = 9; diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 6d2a3aa5b531650a658502531e050702ffbd3760..30e32a46d7dd0923f738939c33407ac7484b5bbe 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -419,7 +419,7 @@ StatusOr> HloAliasAnalysis::Run( auto alias_analysis = WrapUnique(new HloAliasAnalysis(module)); TF_ASSIGN_OR_RETURN( alias_analysis->dataflow_analysis_, - HloDataflowAnalysis::Run(module, /*ssa_form=*/true, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true, /*bitcast_defines_value=*/false)); BufferValueMap buffer_map(alias_analysis->dataflow_analysis()); diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 5432419e4a2dd2916da32ac6566851bf52fd68ca..21e6b2ca730f6347af902097e6496826b861e8a3 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -509,13 +509,14 @@ StatusOr HloComputation::DeepCopyInstruction( "Can't deep copy instruction %s: instruction is not in computation %s", instruction->name().c_str(), name().c_str()); } - if (indices_to_copy != nullptr && !ShapeUtil::Compatible(instruction->shape(), indices_to_copy->shape())) { return FailedPrecondition( "Can't deep copy instruction %s: given shape tree of indices to copy " - "has incompatible shape", - instruction->name().c_str()); + "has incompatible shapes: %s vs. %s", + instruction->name().c_str(), + ShapeUtil::HumanString(instruction->shape()).c_str(), + ShapeUtil::HumanString(indices_to_copy->shape()).c_str()); } ShapeIndex index; diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 061c59abe5e315917161ed737f89de53d71bb1b6..39d864efcb70382b6f8e631d7e6e452ea6410104 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -77,6 +77,14 @@ class HloComputation { return last_added_instruction_; } + Status ForEachInstruction( + const std::function& func) const { + for (const auto& instruction : instructions_) { + TF_RETURN_IF_ERROR(func(instruction.get())); + } + return Status::OK(); + } + private: const string name_; HloInstruction* last_added_instruction_; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 9cd5a1e2b71a7aa768e478289e8e4cc13030fcc3..4ec2ef27bf59b0c877ec38e55ef5c12debeec227 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -229,6 +229,10 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, @@ -529,6 +533,11 @@ Status HloCostAnalysis::HandleConditional(const HloInstruction* conditional) { return Status::OK(); } +Status HloCostAnalysis::HandleGather(const HloInstruction* gather) { + // Gather does not issue any flops. + return Status::OK(); +} + Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index e5783539e5436f09fa58bf7889118380ee90fea0..d17678d20f2a23fd98d18b77d5fb25853901a789 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; + Status HandleHostCompute(const HloInstruction* host_compute) override; Status HandleRng(const HloInstruction* random) override; Status HandleReverse(const HloInstruction* reverse) override; Status HandleSort(const HloInstruction* sort) override; @@ -99,6 +100,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleTranspose(const HloInstruction* transpose) override; Status HandleWhile(const HloInstruction* xla_while) override; Status HandleConditional(const HloInstruction* conditional) override; + Status HandleGather(const HloInstruction* gather) override; Status FinishVisit(const HloInstruction* root) override; Status Preprocess(const HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index d25fc5d7418ae40c7167f88d6172906482a58925..934e43ba4879628362009267c671ec4cb0d79c52 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -38,12 +38,12 @@ namespace xla { using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; -HloDataflowAnalysis::HloDataflowAnalysis(HloModule* module, bool ssa_form, +HloDataflowAnalysis::HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value) : module_(module), ssa_form_(ssa_form), bitcast_defines_value_(bitcast_defines_value), - call_graph_(CallGraph::Build(module)) {} + call_graph_(CallGraph::Build(&module)) {} bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction, const ShapeIndex& index) const { @@ -115,9 +115,9 @@ void HloDataflowAnalysis::DeleteMarkedValues() { } string HloDataflowAnalysis::ToString() const { - string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n"); + string out = StrCat("HloDataflowAnalysis, module ", module_.name(), "\n"); StrAppend(&out, " Instruction value sets:\n"); - for (const HloComputation* computation : module_->computations()) { + for (const HloComputation* computation : module_.computations()) { for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { @@ -585,16 +585,23 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( void HloDataflowAnalysis::Propagate() { std::queue worklist; + tensorflow::gtl::FlatSet workset; + auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { + if (workset.insert(instruction).second) { + worklist.push(instruction); + } + }; - for (HloComputation* computation : module_->computations()) { + for (HloComputation* computation : module_.computations()) { for (HloInstruction* instruction : computation->instructions()) { - worklist.push(instruction); + add_to_worklist(instruction); } } while (!worklist.empty()) { HloInstruction* instruction = worklist.front(); worklist.pop(); + workset.erase(workset.find(instruction)); VLOG(3) << "Worklist top: " << instruction->name(); VLOG(3) << ToString(); @@ -608,9 +615,10 @@ void HloDataflowAnalysis::Propagate() { VLOG(4) << "New value set for " << instruction->name() << ": " << GetInstructionValueSet(instruction); - // Instruction value was updated. Add users to work list. + // Instruction value was updated. Add users to work list if we haven't + // already. for (HloInstruction* user : instruction->users()) { - worklist.push(user); + add_to_worklist(user); // If user sequentially calls a computation, then the respective // parameter(s) of the computation need to be updated. @@ -625,10 +633,10 @@ void HloDataflowAnalysis::Propagate() { // Note that the same instruction can be used in both operand 1 and // operand 2. if (user->operand(1) == instruction) { - worklist.push(user->true_computation()->parameter_instruction(0)); + add_to_worklist(user->true_computation()->parameter_instruction(0)); } if (user->operand(2) == instruction) { - worklist.push(user->false_computation()->parameter_instruction(0)); + add_to_worklist(user->false_computation()->parameter_instruction(0)); } } else { for (HloComputation* called_computation : user->called_computations()) { @@ -636,7 +644,7 @@ void HloDataflowAnalysis::Propagate() { call_graph_->GetNode(called_computation); if (call_graph_node.context() == CallContext::kSequential) { for (int64 operand_number : user->OperandIndices(instruction)) { - worklist.push( + add_to_worklist( called_computation->parameter_instruction(operand_number)); } } @@ -652,13 +660,13 @@ void HloDataflowAnalysis::Propagate() { for (const CallSite& callsite : call_graph_node.caller_callsites()) { if ((callsite.instruction()->opcode() == HloOpcode::kCall) || (callsite.instruction()->opcode() == HloOpcode::kConditional)) { - worklist.push(callsite.instruction()); + add_to_worklist(callsite.instruction()); } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { // Add the while itself, and the body and condition parameters. - worklist.push(callsite.instruction()); - worklist.push( + add_to_worklist(callsite.instruction()); + add_to_worklist( callsite.instruction()->while_body()->parameter_instruction(0)); - worklist.push( + add_to_worklist( callsite.instruction()->while_condition()->parameter_instruction( 0)); } @@ -678,7 +686,7 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( } Status HloDataflowAnalysis::InitializeInstructionValueSets() { - for (const HloComputation* computation : module_->computations()) { + for (const HloComputation* computation : module_.computations()) { const CallGraphNode& call_graph_node = call_graph_->GetNode(computation); for (HloInstruction* instruction : computation->instructions()) { // Create an empty shape tree. @@ -779,9 +787,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { /* static */ StatusOr> HloDataflowAnalysis::Run( - HloModule* module, bool ssa_form, bool bitcast_defines_value) { - VLOG(1) << "HloDataflowAnalysis::Run on module " << module->name(); - XLA_VLOG_LINES(2, module->ToString()); + const HloModule& module, bool ssa_form, bool bitcast_defines_value) { + VLOG(1) << "HloDataflowAnalysis::Run on module " << module.name(); + XLA_VLOG_LINES(2, module.ToString()); auto dataflow_analysis = WrapUnique( new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value)); @@ -798,7 +806,7 @@ StatusOr> HloDataflowAnalysis::Run( // lookup is faster. std::vector> value_positions( dataflow_analysis->next_value_id_); - for (const HloComputation* computation : module->computations()) { + for (const HloComputation* computation : module.computations()) { for (HloInstruction* instruction : computation->instructions()) { for (const auto& pair : dataflow_analysis->GetInstructionValueSet(instruction)) { @@ -850,7 +858,7 @@ Status HloDataflowAnalysis::Verify() const { // For each value in each value set, verify that the value set's position // appears in the value's positions(). - for (const auto& computation : module_->computations()) { + for (const auto& computation : module_.computations()) { for (const auto& instruction : computation->instructions()) { for (const auto& pair : GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 89d318188f0855c7924836a51cfe98d531e08cb4..7b8a74b096ff48733717e78ada5bb56a28caed72 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -60,7 +60,7 @@ class HloDataflowAnalysis { // a new HLO value in the analysis. If false then Bitcast forwards the // value of its operand. static StatusOr> Run( - HloModule* module, bool ssa_form = false, + const HloModule& module, bool ssa_form = false, bool bitcast_defines_value = false); // Returns true if 'instruction' defines an HLO value at the given shape index @@ -119,7 +119,7 @@ class HloDataflowAnalysis { string ToString() const; protected: - HloDataflowAnalysis(HloModule* module, bool ssa_form, + HloDataflowAnalysis(const HloModule& module, bool ssa_form, bool bitcast_defines_value = false); // Returns a new HloValue defined at the given instruction and shape index. @@ -180,7 +180,7 @@ class HloDataflowAnalysis { // Verify various invariants of the dataflow analysis. Status Verify() const; - HloModule* const module_; + const HloModule& module_; const bool ssa_form_; const bool bitcast_defines_value_; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index e714b2567fd1b3eab607a19f0bb7e3288150dc64..7bf3a1a06045c79621d75b653bf42220705a69d4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -50,7 +50,7 @@ class HloDataflowAnalysisTest : public HloTestBase, bool bitcast_defines_value = false) { hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before dataflow analysis"); analysis_ = - HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value) + HloDataflowAnalysis::Run(*module_, ssa_form, bitcast_defines_value) .ConsumeValueOrDie(); return *analysis_; } diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 1e5f0f797a13fd7e7ce1cc934387a274a74153bc..fcd723af146e2227b8661b1a4993f1338f7de389 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -40,7 +40,7 @@ StatusOr HloDCE::Run(HloModule* module) { VLOG(2) << "Before dce:"; XLA_VLOG_LINES(2, module->ToString()); - for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* computation : module->MakeComputationPostOrder()) { std::unordered_set live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( [&live_instructions](HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index 81212cda4266ec820230d0d84fc2a395edaf411e..afbfdac05e1b09abfe2555316dc13c8334dd6182 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -34,8 +34,6 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" @@ -59,6 +57,12 @@ struct is_complex_t : public std::false_type {}; template <> struct is_complex_t : public std::true_type {}; +template +struct is_complex64_t : public std::false_type {}; + +template <> +struct is_complex64_t : public std::true_type {}; + template StatusOr> Compare(const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, @@ -250,17 +254,37 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, - typename std::enable_if::value || - is_complex_t::value>::type* = nullptr> + typename std::enable_if::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) { + ElementWiseUnaryOp(abs, [](NativeT elem_operand) { return std::abs(elem_operand); })); return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { + const Literal& operand_literal = + parent_->GetEvaluatedLiteralFor(abs->operand(0)); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[abs], + (ElementWiseUnaryOpImpl( + abs, [](NativeT elem_operand) { return std::abs(elem_operand); }, + operand_literal))); + + return Status::OK(); + } + Status HandleAbs(HloInstruction* abs) override { + // If the operand is of C64 type, the return type of abs will be F32. + // However, ElementwiseT would still be the return type, F32, and thus + // specifying the ElementwiseT explicitly as C64 is needed below. + if (abs->operand(0)->shape().element_type() == C64) { + return HandleAbs(abs); + } return HandleAbs(abs); } @@ -742,7 +766,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[shl], ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { - return lhs_elem << rhs_elem; + return IsShiftOutOfBounds(rhs_elem) ? 0 + : (lhs_elem << rhs_elem); })); return Status::OK(); } @@ -767,8 +792,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[shr], ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { - return static_cast(static_cast(lhs_elem) >> - rhs_elem); + SignedT lhs_signed = static_cast(lhs_elem); + if (IsShiftOutOfBounds(rhs_elem)) { + return lhs_signed < 0 ? static_cast(-1) : 0; + } else { + return lhs_signed >> rhs_elem; + } })); return Status::OK(); } @@ -794,6 +823,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN( parent_->evaluated_[shr], ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + // If shift amount is greater than the number of bits, then return 0. + if (IsShiftOutOfBounds(rhs_elem)) { + return static_cast(0); + } return static_cast(static_cast(lhs_elem) >> rhs_elem); })); @@ -1403,6 +1436,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break; } + case F16: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], + MapImpl(map)); + break; + } case F32: { TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); break; @@ -2024,6 +2062,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return std::move(result); } + template + static bool IsShiftOutOfBounds(NativeT rhs) { + typedef typename std::make_unsigned::type UnsignedT; + UnsignedT lhs_size_unsigned = sizeof(NativeT) * CHAR_BIT; + UnsignedT rhs_unsigned = static_cast(rhs); + return rhs_unsigned >= lhs_size_unsigned; + } + HloEvaluator* parent_; }; // class HloEvaluator::TypedVisitor @@ -2041,9 +2087,7 @@ HloEvaluator::HloEvaluator() { }); typed_visitors_[S32] = MakeUnique>(this); typed_visitors_[S64] = MakeUnique>(this); - typed_visitors_[F16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: F16."); - }); + typed_visitors_[F16] = MakeUnique>(this); typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); typed_visitors_[C64] = MakeUnique>(this); @@ -2427,6 +2471,54 @@ Status HloEvaluator::HandleCopy(HloInstruction* copy) { return Status::OK(); } +Status HloEvaluator::HandleCall(HloInstruction* call) { + auto* computation = call->to_apply(); + auto operands = call->operands(); + + std::vector arg_literals; + arg_literals.reserve(operands.size()); + for (auto operand : operands) { + const Literal& arg_literal = GetEvaluatedLiteralFor(operand); + arg_literals.push_back(&arg_literal); + } + + HloEvaluator embedded_evaluator; + std::unique_ptr result = + embedded_evaluator.Evaluate(*computation, arg_literals) + .ConsumeValueOrDie(); + + evaluated_[call] = std::move(result); + return Status::OK(); +} + +Status HloEvaluator::HandleConditional(HloInstruction* conditional) { + const auto& pred = GetEvaluatedLiteralFor(conditional->operand(0)); + const auto& true_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(1)); + const auto& false_computation_arg = + GetEvaluatedLiteralFor(conditional->operand(2)); + + auto* true_computation = conditional->true_computation(); + auto* false_computation = conditional->false_computation(); + + auto result = Literal::CreateFromShape(conditional->shape()); + HloEvaluator embedded_evaluator; + if (pred.Get({})) { + result = embedded_evaluator + .Evaluate(*true_computation, + {&true_computation_arg}) + .ConsumeValueOrDie(); + } else { + result = embedded_evaluator + .Evaluate(*false_computation, + {&false_computation_arg}) + .ConsumeValueOrDie(); + } + + evaluated_[conditional] = std::move(result); + return Status::OK(); +} + Status HloEvaluator::Preprocess(HloInstruction* hlo) { VLOG(2) << "About to visit HLO: " << hlo->ToString(); return Status::OK(); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 3b2b697e492a78a06a4e5ae6bf056ff8676f2ff5..fc8201163082576b0c1146da7bc14b468695cca8 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -153,6 +153,10 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; + + Status HandleCall(HloInstruction* call) override; + private: // Returns the already-evaluated literal result for the instruction. // A Constant instruction is considered evaluated and its literal will be diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 44fcd36370dcd0cf77601aa1cd2b92810947bd5f..2861fec39ef0c92fdfbcee04584f9bd36d3cb4d8 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -940,6 +940,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kConcatenate: case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: case HloOpcode::kPad: case HloOpcode::kReshape: case HloOpcode::kReverse: @@ -988,6 +989,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kHostCompute: case HloOpcode::kWhile: return kDarkGreen; case HloOpcode::kConstant: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 0e4437b73b5afb6519506f8afd2b962c5be2111b..a534d8ff063fef9336b0c87a5b88457384f694d2 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -801,6 +801,22 @@ static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { return instruction; } +HloInstruction* HloInstruction::AddFusionOperand(HloInstruction* new_operand) { + CHECK_EQ(opcode(), HloOpcode::kFusion); + CHECK_EQ(operand_count(), + fused_instructions_computation()->parameter_instructions().size()); + const int64 param_no = operand_count(); + // Name the parameter after the instruction it represents in the outer + // (non-fusion) computation. + string param_name = StrCat(new_operand->name(), ".param_", param_no); + HloInstruction* fused_parameter = + fused_instructions_computation()->AddParameter( + HloInstruction::CreateParameter(param_no, new_operand->shape(), + param_name)); + AppendOperand(new_operand); + return fused_parameter; +} + void HloInstruction::MergeFusionInstruction( HloInstruction* instruction_to_merge) { CHECK_EQ(opcode_, HloOpcode::kFusion); @@ -993,13 +1009,7 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // Clone's operand was not already an operand of the fusion // instruction. Add it as an operand and add a corresponding fused // parameter instruction. - int64 param_no = fused_parameters.size(); - // Name the parameter after the instruction it represents in the outer - // (non-fusion) computation. - string param_name = StrCat(operand->name(), ".param_", param_no); - fused_param = fused_instructions_computation()->AddParameter( - CreateParameter(param_no, operand->shape(), param_name)); - AppendOperand(operand); + fused_param = AddFusionOperand(operand); } TF_CHECK_OK(clone->ReplaceOperandWith(operand_num, fused_param)); } @@ -1084,6 +1094,7 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: + case HloOpcode::kHostCompute: return true; default: { // Check if any of the called computations has a side effect. @@ -1121,6 +1132,19 @@ bool HloInstruction::HasSideEffect() const { return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateHostCompute( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { + std::unique_ptr instruction = + WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->channel_name_ = channel_name.ToString(); + instruction->cost_estimate_ns_ = cost_estimate_ns; + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateTuple( tensorflow::gtl::ArraySlice elements) { std::vector element_shapes; @@ -1131,6 +1155,40 @@ bool HloInstruction::HasSideEffect() const { return CreateVariadic(tuple_shape, HloOpcode::kTuple, elements); } +/* static */ std::unique_ptr HloInstruction::CreateGather( + const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + std::unique_ptr instruction = + WrapUnique(new HloInstruction(HloOpcode::kGather, shape)); + instruction->AppendOperand(operand); + instruction->AppendOperand(gather_indices); + instruction->gather_dimension_numbers_ = + MakeUnique(gather_dim_numbers); + c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_)); + return instruction; +} + +/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice output_window_dims, + tensorflow::gtl::ArraySlice elided_window_dims, + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim) { + GatherDimensionNumbers gather_dim_numbers; + for (int64 output_window_dim : output_window_dims) { + gather_dim_numbers.add_output_window_dims(output_window_dim); + } + for (int64 elided_window_dim : elided_window_dims) { + gather_dim_numbers.add_elided_window_dims(elided_window_dim); + } + for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { + gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); + } + + gather_dim_numbers.set_index_vector_dim(index_vector_dim); + return gather_dim_numbers; +} + std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, tensorflow::gtl::ArraySlice new_operands, @@ -1212,6 +1270,10 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCustomCall: clone = CreateCustomCall(shape, new_operands, custom_call_target_); break; + case HloOpcode::kHostCompute: + clone = CreateHostCompute(shape, new_operands, channel_name_, + cost_estimate_ns_); + break; case HloOpcode::kConcatenate: clone = CreateConcatenate(shape, new_operands, dimensions(0)); break; @@ -1361,12 +1423,19 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( break; case HloOpcode::kRecv: CHECK_EQ(new_operands.size(), 0); - clone = CreateRecv(shape, channel_id()); + // The shape is a tuple, but CreateRecv() wants the raw data shape. + clone = + CreateRecv(ShapeUtil::GetTupleElementShape(shape, 0), channel_id()); break; case HloOpcode::kRecvDone: CHECK_EQ(new_operands.size(), 1); clone = CreateRecvDone(new_operands[0]); break; + case HloOpcode::kGather: + CHECK_EQ(new_operands.size(), 2); + clone = CreateGather(shape, new_operands[0], new_operands[1], + *gather_dimension_numbers_, gather_window_bounds_); + break; case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } @@ -1710,6 +1779,11 @@ bool HloInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(dot_dimension_numbers(), other.dot_dimension_numbers()); + case HloOpcode::kGather: + return protobuf_util::ProtobufEquals(gather_dimension_numbers(), + other.gather_dimension_numbers()) && + gather_window_bounds() == other.gather_window_bounds(); + // FFT has various types & lengths. case HloOpcode::kFft: return fft_type() == other.fft_type() && @@ -1780,6 +1854,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kRecvDone: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kHostCompute: return false; } } @@ -1805,7 +1880,8 @@ void HloInstruction::RemoveUser(HloInstruction* user) { Status HloInstruction::ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer) { - TF_RET_CHECK(ShapeUtil::Compatible(shape(), new_producer->shape())) + TF_RET_CHECK( + ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape())) << "this shape: " << ShapeUtil::HumanString(shape()) << ", replacement shape: " << ShapeUtil::HumanString(new_producer->shape()); @@ -1828,8 +1904,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); - TF_RET_CHECK( - ShapeUtil::Compatible(old_operand->shape(), new_operand->shape())) + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), + new_operand->shape())) << old_operand->shape().ShortDebugString() << " is not compatible with " << new_operand->shape().ShortDebugString(); operands_[operand_num] = new_operand; @@ -2139,6 +2215,11 @@ std::vector HloInstruction::ExtraAttributesToString( if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } + if (gather_dimension_numbers_ != nullptr) { + extra.push_back(GatherDimensionNumbersToString()); + extra.push_back( + StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); + } if (opcode() == HloOpcode::kFft) { extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); @@ -2270,6 +2351,14 @@ HloInstructionProto HloInstruction::ToProto() const { if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } + if (gather_dimension_numbers_ != nullptr) { + *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; + } + if (opcode() == HloOpcode::kGather) { + for (int64 bound : gather_window_bounds()) { + proto.add_gather_window_bounds(bound); + } + } for (int i = 0; i < slice_starts_.size(); ++i) { auto* slice_dimension = proto.add_slice_dimensions(); slice_dimension->set_start(slice_starts_[i]); @@ -2564,6 +2653,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleInfeed(this); case HloOpcode::kOutfeed: return visitor->HandleOutfeed(this); + case HloOpcode::kHostCompute: + return visitor->HandleHostCompute(this); case HloOpcode::kRng: return visitor->HandleRng(this); case HloOpcode::kWhile: @@ -2584,6 +2675,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSend(this); case HloOpcode::kSendDone: return visitor->HandleSendDone(this); + case HloOpcode::kGather: + return visitor->HandleGather(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -3267,6 +3360,26 @@ string HloInstruction::DotDimensionNumbersToString() const { return Join(result, ", "); } +string HloInstruction::GatherDimensionNumbersToString() const { + CHECK_NE(gather_dimension_numbers_.get(), nullptr); + string output_window_dims = + StrCat("output_window_dims={", + Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); + string elided_window_dims = + StrCat("elided_window_dims={", + Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); + string gather_dims_to_operand_dims = StrCat( + "gather_dims_to_operand_dims={", + Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); + string index_vector_dim = StrCat( + "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); + + return Join>( + {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, + index_vector_dim}, + ", "); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 3170746157fbcfa7d0a7eaba6d226d46691105f9..e4c86214c2014095b2e171ff10691e1221574cb7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -451,6 +451,12 @@ class HloInstruction { HloInstruction* true_computation_arg, HloComputation* true_computation, HloInstruction* false_computation_arg, HloComputation* false_computation); + static std::unique_ptr CreateGather( + const Shape& shape, HloInstruction* operand, + HloInstruction* gather_indices, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds); + // Creates a fusion instruction. A fusion instruction contains one or more // fused instructions forming an expression with a single root // "fused_root". Additional instructions can be added to the fusion @@ -475,6 +481,12 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice operands, tensorflow::StringPiece custom_call_target); + // Creates a HostCompute instruction, which records host-side control and + // data dependencies for use in instruction scheduling. + static std::unique_ptr CreateHostCompute( + const Shape& shape, tensorflow::gtl::ArraySlice operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. static std::unique_ptr CreateTuple( @@ -486,6 +498,13 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice dimensions); + // Creates an instance of GatherDimensionNumbers. + static GatherDimensionNumbers MakeGatherDimNumbers( + tensorflow::gtl::ArraySlice output_window_dims, + tensorflow::gtl::ArraySlice elided_window_dims, + tensorflow::gtl::ArraySlice gather_dims_to_operand_dims, + int64 index_vector_dim); + // Returns the opcode for this instruction. HloOpcode opcode() const { return opcode_; } @@ -767,6 +786,10 @@ class HloInstruction { // // (We express the default options using an overload rather than a default // param because gdb ignores default params, but does resolve overloads.) + // + // TODO(b/73348663): Make ToString() adaptive to the size of the string by + // default, backing off on providing full information for very large strings, + // or provide a different name for a ToString-like function that does that. string ToString() const { return ToString(HloPrintOptions()); } string ToString(const HloPrintOptions& options) const; @@ -802,6 +825,12 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kSend or HloOpcode::kRecv int64 channel_id() const { return channel_id_; } + // Returns the channel name associated with the instruction. The name is + // used to identify host Send/Recv operations. + // + // Precondition: opcode() == HloOpcode::kHostCompute + string channel_name() const { return channel_name_; } + // Returns feature_index field associated with the instruction. The index // represents the index of the feature dimension. // @@ -914,6 +943,9 @@ class HloInstruction { // Return true if this operator has a sharding assigned. bool has_sharding() const { return sharding_ != nullptr; } + // Adds a new operand the fusion instruction. + HloInstruction* AddFusionOperand(HloInstruction* new_operand); + // Merges the fused instructions from 'instruction_to_merge' into the // fused instruction set of 'this', updating operands as necessary. // @@ -1086,6 +1118,19 @@ class HloInstruction { // Returns the dump string of the dot dimension numbers. string DotDimensionNumbersToString() const; + const GatherDimensionNumbers& gather_dimension_numbers() const { + CHECK(gather_dimension_numbers_ != nullptr); + return *gather_dimension_numbers_; + } + + tensorflow::gtl::ArraySlice gather_window_bounds() const { + CHECK_EQ(opcode(), HloOpcode::kGather); + return gather_window_bounds_; + } + + // Returns the dump string of the gather dimension numbers. + string GatherDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -1350,6 +1395,9 @@ class HloInstruction { // Describes the dimension numbers used for a dot. std::unique_ptr dot_dimension_numbers_; + std::unique_ptr gather_dimension_numbers_; + std::vector gather_window_bounds_; + // Describes FFT type for an FFT instruction. FftType fft_type_ = FftType::FFT; @@ -1388,6 +1436,12 @@ class HloInstruction { // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; + // Name to use for host send/recv channels, only present for kHostCompute. + string channel_name_; + + // Estimate of the duration of a host computation in nanoseconds. + int64 cost_estimate_ns_; + // Computations called by this instruction. std::vector called_computations_; diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 94e9bfe56eb445ec0b459a55342cd3cc4c6f68ef..f2980d309d01fdf3b3e601bc260a0ad0895b3064 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -1271,5 +1271,77 @@ TEST_F(HloInstructionTest, Stringification) { "true_computation=%TransposeDot, false_computation=%TransposeDot"); } +TEST_F(HloInstructionTest, StringifyGather_0) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape gather_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + Shape gather_result_shape = + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Gather"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* gather_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, gather_indices_tensor_shape, "gather_indices")); + + HloInstruction* gather_instruction = + builder.AddInstruction(HloInstruction::CreateGather( + gather_result_shape, input, gather_indices, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(gather_instruction->ToString(), + "%gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " + "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), " + "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " + "gather_dims_to_operand_dims={0,1,2,3,4}, " + "index_vector_dim=4, window_bounds={30,29,28,27,26}"); +} + +TEST_F(HloInstructionTest, StringifyGather_1) { + Shape input_tensor_shape = ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + Shape gather_indices_tensor_shape = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + Shape gather_result_shape = + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}); + + HloComputation::Builder builder("Gather"); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_tensor_shape, "input_tensor")); + HloInstruction* gather_indices = + builder.AddInstruction(HloInstruction::CreateParameter( + 1, gather_indices_tensor_shape, "gather_indices")); + + HloInstruction* gather_instruction = + builder.AddInstruction(HloInstruction::CreateGather( + gather_result_shape, input, gather_indices, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + HloModule module(TestName()); + module.AddEntryComputation(builder.Build()); + + EXPECT_EQ(gather_instruction->ToString(), + "%gather = f32[10,9,7,6,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} " + "gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, " + "s64[10,9,5,7,6]{4,3,2,1,0} %gather_indices), " + "output_window_dims={4,5,6,7,8}, elided_window_dims={}, " + "gather_dims_to_operand_dims={0,1,2,3,4}, " + "index_vector_dim=2, window_bounds={30,29,28,27,26}"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 60270b0595dcfca8f1fcea5ab0914428880f35b5..cb2fe9f874012a51e1e6cbd1dd086dbb26994bde 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -145,6 +145,21 @@ void HloModule::ReplaceComputations( } break; } + case HloOpcode::kConditional: { + HloComputation* new_true_computation = + tensorflow::gtl::FindWithDefault( + replacements, instruction->true_computation(), nullptr); + if (new_true_computation != nullptr) { + instruction->set_true_computation(new_true_computation); + } + HloComputation* new_false_computation = + tensorflow::gtl::FindWithDefault( + replacements, instruction->false_computation(), nullptr); + if (new_false_computation != nullptr) { + instruction->set_false_computation(new_false_computation); + } + break; + } case HloOpcode::kSelectAndScatter: { HloComputation* new_select = tensorflow::gtl::FindWithDefault( replacements, instruction->select(), nullptr); @@ -563,6 +578,18 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { return module; } +HloComputation* HloModule::DeepCloneComputation(HloComputation* computation) { + HloComputation* clone = AddEmbeddedComputation(computation->Clone("", this)); + TF_CHECK_OK( + clone->root_instruction()->Accept([this](HloInstruction* instruction) { + instruction->ReplaceCalledComputations([this](HloComputation* callee) { + return DeepCloneComputation(callee); + }); + return Status::OK(); + })); + return clone; +} + uint64 HloModule::RandomNew64() const { tensorflow::mutex_lock l(rng_mutex_); return rng_(); diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 4bfe8d89ce0a285de6d05d4867aaa6b266d78d12..06d92f94fd6f62162b22575e9cc341f2906cd0db 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -85,6 +85,10 @@ class HloModule { // Returns a deep copy of this module including all computations. std::unique_ptr Clone(const string& suffix = "clone") const; + // Performs a deep clone of the computation, by recursively cloning all + // the called computations as well. + HloComputation* DeepCloneComputation(HloComputation* computation); + // Return a pointer to the entry computation of the module.. const HloComputation* entry_computation() const { CHECK_NE(nullptr, entry_computation_); diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index a5ee895e48448fbb8fa3879dc1b6764c1f9f6966..d3c1fae592bb465609ffbde2d0262e2600912e63 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -67,6 +67,15 @@ class HloModuleConfig { bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; } void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; } + // Sets/returns whether this is a "host module". Host modules are used to + // record the data- and control-flow dependencies of host side computation + // that communicates with compiled code. They are used for analysis and + // scheduling purposes, but no code is generated. + bool is_host_module() const { return is_host_module_; } + void set_is_host_module(bool is_host_module) { + is_host_module_ = is_host_module; + } + // Sets/returns the module seed set during execution. void set_seed(uint64 seed) { seed_ = seed; } uint64 seed() const { return seed_; } @@ -104,6 +113,9 @@ class HloModuleConfig { // Whether to enable HLO-level profiling. bool hlo_profiling_enabled_ = false; + // Whether this is a 'host module'. + bool is_host_module_ = false; + // Module/graph-level seed handle. uint64 seed_ = 0; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 3d64523a79fc50638fdf378b5d521a5cd4482b90..af24604c39b554f146793594958f373999844b4c 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -76,9 +76,11 @@ namespace xla { V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ + V(kGather, "gather") \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ + V(kHostCompute, "host-compute") \ V(kImag, "imag") \ V(kInfeed, "infeed") \ V(kIsFinite, "is-finite") \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 68e3c9618c1fe9daacb0aee3ee98862c8b9e4bc4..1b24d8da9e832e6847cb6f405e15af3c455f695a 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -186,6 +186,22 @@ bool HloOrdering::UseIsBeforeValueDefinition( } } + if (use.instruction->opcode() == HloOpcode::kConditional) { + const HloInstruction* conditional = use.instruction; + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + conditional->true_computation())) { + VLOG(4) << " use is conditional " << use.instruction->name() + << " and def is in TRUE computation"; + return true; + } + if (call_graph_->InstructionIsNestedIn(value.defining_instruction(), + conditional->false_computation())) { + VLOG(4) << " use is conditional " << use.instruction->name() + << " and def is in FALSE computation"; + return true; + } + } + VLOG(4) << " use is not before value"; return false; } diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index aba66114de649ce7667ae77174e9c4073b010b90..a989fce63234cb860d08c48b02462e96bec879bc 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -262,8 +262,8 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { scalar_shape, HloOpcode::kAdd, constant, xla_while)); module->AddEntryComputation(builder.Build()); - TF_ASSERT_OK_AND_ASSIGN( - auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true)); + TF_ASSERT_OK_AND_ASSIGN(auto dataflow, + HloDataflowAnalysis::Run(*module, /*ssa_form=*/true)); DependencyHloOrdering ordering(module.get()); // Init value is defined before the while, but live range is not before the diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index c6b4dc0368d92fd477decdfb38045f74f8696803..98b8d34be1f331aaeac94e952deeae1e76379861 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -60,6 +60,7 @@ bool IsRematerializable(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kConstant: + case HloOpcode::kConditional: case HloOpcode::kCrossReplicaSum: case HloOpcode::kCustomCall: case HloOpcode::kParameter: diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 5f5a930dad002c215a5332286ade97ef19cc67af..f6e33403f538bd8492b04c34d46a458f7f06cc06 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -15,7 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" -#include +#include #include #include @@ -101,7 +101,7 @@ class ListScheduler { // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. for (auto* instruction : computation.instructions()) { - std::unordered_set instr_uses; + tensorflow::gtl::FlatSet instr_uses; for (auto* operand : instruction->operands()) { for (const LogicalBuffer* buffer : points_to_analysis.GetBuffersDefinedByInstruction(operand)) { @@ -151,7 +151,7 @@ class ListScheduler { int64 bytes_defined; // For each buffer B used by this instruction, we keep a pair (B, U), where - // U is the number of uses of B that have not yet been scheduled. This pair + // U is the number of uses of B that have not yet been scheduled. This pair // is a pointer into the unscheduled_use_count_ map, so it gets updated for // free when we update counts in the map. std::vector*> @@ -206,7 +206,8 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. - std::unordered_map unscheduled_pred_count; + tensorflow::gtl::FlatMap + unscheduled_pred_count; for (auto* instruction : computation_.instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. @@ -218,33 +219,48 @@ class ListScheduler { } } - auto priority_comparator = [this](const ReadyListEntry& lhs, - const ReadyListEntry& rhs) { - return GetPriority(lhs) < GetPriority(rhs); + // Use a multimap to sort ReadyListEntry according to their priority. + std::multimap ready_queue; + + // Map of ready instructions to their iterators in ready_queue. + tensorflow::gtl::FlatMap::iterator> + ready_instructions; + + auto add_to_ready_queue = [&](HloInstruction* inst) { + auto entry = MakeReadyListEntry(inst); + auto it = ready_queue.emplace(GetPriority(entry), std::move(entry)); + ready_instructions[inst] = it; }; - std::priority_queue, - decltype(priority_comparator)> - ready_queue(priority_comparator); + for (auto* instruction : computation_.instructions()) { // Instruction with no operands or control predecessors will // not be in the map. if (unscheduled_pred_count.count(instruction) == 0) { - ready_queue.emplace(MakeReadyListEntry(instruction)); + add_to_ready_queue(instruction); } } while (!ready_queue.empty()) { // Remove the selected instruction from the ready list and add it to the // schedule. - const HloInstruction* best = ready_queue.top().instruction; - ready_queue.pop(); + auto best_it = ready_queue.end(); + --best_it; + const HloInstruction* best = best_it->second.instruction; + ready_queue.erase(best_it); + ready_instructions.erase(best); schedule.push_back(best); scheduled_instructions_.insert(best); + bool adjust_ready_queue = false; // Update the unscheduled uses of the logical buffers. for (const LogicalBuffer* buffer : buffer_uses_.at(best)) { - CHECK_GT(unscheduled_use_count_.at(buffer), 0); - --unscheduled_use_count_[buffer]; + int64& count = unscheduled_use_count_[buffer]; + CHECK_GT(count, 0); + --count; + if (count == 1) { + adjust_ready_queue = true; + } } // Add new instructions to ready list. @@ -252,7 +268,7 @@ class ListScheduler { int64 pred_count = --unscheduled_pred_count.at(inst); CHECK_GE(pred_count, 0); if (pred_count == 0) { - ready_queue.emplace(MakeReadyListEntry(inst)); + add_to_ready_queue(inst); } }; // TODO(b/34466113): Replace this and above with successors() or @@ -263,6 +279,31 @@ class ListScheduler { for (HloInstruction* succ : best->control_successors()) { update_pred_count(succ); } + // The unscheduled use count for a buffer has changed to 1, so the + // priorities of some ready instructions may go up. We update them in the + // ready queue, so that they can appear earlier. + if (adjust_ready_queue) { + for (HloInstruction* operand : best->operands()) { + for (HloInstruction* operand_user : operand->users()) { + auto ready_instructions_it = ready_instructions.find(operand_user); + if (ready_instructions_it == ready_instructions.end()) { + continue; + } + auto ready_queue_it = ready_instructions_it->second; + auto& entry = ready_queue_it->second; + Priority new_priority = GetPriority(entry); + if (new_priority == ready_queue_it->first) { + continue; + } + // Create a new entry in ready_queue, then update + // ready_instructions[operand_user] to refer to the new entry. + ready_instructions_it->second = + ready_queue.emplace(new_priority, std::move(entry)); + // Remove the old entry in ready_queue. + ready_queue.erase(ready_queue_it); + } + } + } } CHECK_EQ(schedule.size(), computation_.instruction_count()); CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); @@ -275,15 +316,17 @@ class ListScheduler { const LogicalBuffer::SizeFunction& size_function_; // A map containing the LogicalBuffers that each instruction uses. - std::unordered_map> + tensorflow::gtl::FlatMap> buffer_uses_; // A map containing the count of unscheduled HLOs which using a particular - // LogicalBuffer. We rely on iterator stability in this map. + // LogicalBuffer. We rely on iterator stability in this map, and that the map + // entries are std::pair's. std::unordered_map unscheduled_use_count_; // Set of instructions which have been scheduled. - std::unordered_set scheduled_instructions_; + tensorflow::gtl::FlatSet scheduled_instructions_; }; int64 SumLogicalBufferSizes( diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index 447c2446668253c932b44b51b2db22bfd47f9957..afe79c9f17befdcb2812c0a08b205f21b0715b19 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -183,6 +183,10 @@ Status HloSharding::ValidateTuple(const Shape& shape, int64 num_devices) const { // shape tree. ShapeTree shape_tree = GetAsShapeTree(shape); for (const auto& index_to_sharding : shape_tree.leaves()) { + if (index_to_sharding.first.empty()) { + // An empty tuple has a ShapeTree with a single leaf at the empty index. + continue; + } Status status = index_to_sharding.second.ValidateNonTuple( ShapeUtil::GetSubshape(shape, index_to_sharding.first), num_devices); if (!status.ok()) { @@ -222,7 +226,7 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, Status status = Status::OK(); std::set seen_cores; tile_assignment_.Each( - [&](tensorflow::gtl::ArraySlice indices, uint32 core) { + [&](tensorflow::gtl::ArraySlice indices, int32 core) { // Don't overwrite a bad status, so we report the first error. if (status.ok()) { if (core >= num_devices) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 04d4656546684063d3d6532e443ad7995c6ea8db..b1fd068115e1d104a11d880675ef84e07d6d5602 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -123,6 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { return CheckShape(outfeed, ShapeUtil::MakeNil()); } +Status ShapeVerifier::HandleHostCompute(HloInstruction*) { + return tensorflow::Status::OK(); +} + Status ShapeVerifier::HandleRng(HloInstruction*) { return tensorflow::Status::OK(); } @@ -164,6 +170,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { // HLO broadcast has no exact analog at the proto level so there is no // ShapeInference method. Check the output shape explicitly. const Shape& operand_shape = broadcast->operand(0)->shape(); + // Check for mixed precision. + TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape())); TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == broadcast->dimensions().size()); for (int64 operand_dimension = 0; @@ -178,6 +186,8 @@ Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { } Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + // Check for mixed precision. + TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == ShapeUtil::ElementsIn(reshape->operand(0)->shape())); return tensorflow::Status::OK(); @@ -359,13 +369,130 @@ Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { batch_norm_grad->feature_index())); } +namespace { + +// Checks that the instruction does not have mixed precision floating point +// inputs. +Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { + switch (instruction->opcode()) { + // White list the following opcodes for mixed-precision check, because they + // involve data pass through or grouping via tuples, where the precisions + // of buffers can be different. + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kCustomCall: + case HloOpcode::kFusion: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReducePrecision: + case HloOpcode::kSelect: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + break; + default: { + PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID; + for (auto operand : instruction->operands()) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + operand->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::ElementIsFloating(subshape)) { + return Status::OK(); + } + if (fp_type == PRIMITIVE_TYPE_INVALID) { + fp_type = subshape.element_type(); + } else if (fp_type != subshape.element_type()) { + return FailedPrecondition( + "Seen floating point types of different precisions in " + "%s, but mixed precision is disallowed.", + instruction->ToString().c_str()); + } + return Status::OK(); + })); + } + } + } + return Status::OK(); +} + +} // namespace + +Status ShapeVerifier::HandleGather(HloInstruction* gather) { + return CheckShape( + gather, + ShapeInference::InferGatherShape( + gather->operand(0)->shape(), gather->operand(1)->shape(), + gather->gather_dimension_numbers(), gather->gather_window_bounds())); +} + Status ShapeVerifier::CheckShape(const HloInstruction* instruction, - const Shape& expected_shape) { - if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) { + const Shape& inferred_shape) { + // If allow_mixed_precision_ is false, check if there are operands with + // different precisions. We need this check because ShapeInference allows + // mixed precision inputs. + if (!allow_mixed_precision_) { + TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction)); + } + + // Check if the output shape matches the expected shape. + bool compatible; + // We treat BF16 and F32 as compatible types if mixed precision is allowed, + // but only when the instruction defines the BF16/F32 buffer. + switch (instruction->opcode()) { + case HloOpcode::kSelect: + if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) { + // Select only defines the top-level buffer, which in this case is the + // tuple, so we cannot allow mixed precision. + compatible = + ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } else { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision( + instruction->shape(), inferred_shape); + } + break; + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed + // precision is disallowed. + case HloOpcode::kConstant: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCustomCall: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kWhile: + // The above opcodes should match the expected shapes exactly. + compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); + break; + default: + if (allow_mixed_precision_) { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision( + instruction->shape(), inferred_shape); + } else { + compatible = + ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } + } + if (!compatible) { return InvalidArgument( "Expected instruction to have shape compatible with %s, actual " "shape is %s:\n%s", - ShapeUtil::HumanString(expected_shape).c_str(), + ShapeUtil::HumanString(inferred_shape).c_str(), ShapeUtil::HumanString(instruction->shape()).c_str(), instruction->ToString().c_str()); } @@ -373,14 +500,14 @@ Status ShapeVerifier::CheckShape(const HloInstruction* instruction, } Status ShapeVerifier::CheckShape(const HloInstruction* instruction, - const StatusOr& expected_shape_status) { - if (!expected_shape_status.ok()) { - Status s = expected_shape_status.status(); + const StatusOr& inferred_shape_status) { + if (!inferred_shape_status.ok()) { + Status s = inferred_shape_status.status(); tensorflow::errors::AppendToMessage(&s, ", for instruction ", instruction->ToString()); return s; } - return CheckShape(instruction, expected_shape_status.ValueOrDie()); + return CheckShape(instruction, inferred_shape_status.ValueOrDie()); } Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 26d53dec1e52f0bf19d6a8af998c56db8a850518..1dd7ec3c51e18dcfe89bd478de87798ba3858119 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -27,6 +27,10 @@ namespace xla { // TODO(b/26024837): Check output shape for all instruction types. class ShapeVerifier : public DfsHloVisitor { public: + explicit ShapeVerifier() : allow_mixed_precision_(false) {} + explicit ShapeVerifier(bool allow_mixed_precision) + : allow_mixed_precision_(allow_mixed_precision) {} + Status HandleElementwiseUnary(HloInstruction* hlo) override; Status HandleElementwiseBinary(HloInstruction* hlo) override; Status HandleClamp(HloInstruction* clamp) override; @@ -56,6 +60,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFusion(HloInstruction*) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction*) override; + Status HandleHostCompute(HloInstruction*) override; Status HandleSlice(HloInstruction* slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( @@ -75,20 +80,21 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleBatchNormInference( HloInstruction* batch_norm_inference) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + Status HandleGather(HloInstruction* gather) override; Status FinishVisit(HloInstruction*) override { return tensorflow::Status::OK(); } protected: - // Check the instruction's shape against the given expected shape and return - // an appropriate error if there is a mismatch. + // Check the instruction's shape against the shape given by ShapeInference + // and return an appropriate error if there is a mismatch. Status CheckShape(const HloInstruction* instruction, - const Shape& expected_shape); + const Shape& inferred_shape); // Overload which takes a StatusOr to reduce boilerplate in the caller. Status CheckShape(const HloInstruction* instruction, - const StatusOr& expected_shape_status); + const StatusOr& inferred_shape_status); // Check a unary (binary, etc) instruction's shape against the inferred shape. Status CheckUnaryShape(const HloInstruction* instruction); @@ -99,19 +105,32 @@ class ShapeVerifier : public DfsHloVisitor { // Checks if the given two instructions shares the same channel id. Status CheckSameChannel(const HloInstruction* instr1, const HloInstruction* instr2); + + private: + // Whether the inputs and output of an instruction can contain both F32s and + // BF16s. Tuples that include both F32s and BF16s are allowed regardless of + // this flag. + bool allow_mixed_precision_; }; // HLO pass that verifies invariants of HLO instructions for each computation in // the module. class HloVerifier : public HloPassInterface { public: + using ShapeVerifierFactory = std::function()>; + // Uses standard shape inference. explicit HloVerifier() - : shape_verifier_factory_([] { return MakeUnique(); }) {} + : shape_verifier_factory_( + [] { return MakeUnique(false); }) {} + + explicit HloVerifier(bool allow_mixed_precision) + : shape_verifier_factory_([allow_mixed_precision] { + return MakeUnique(allow_mixed_precision); + }) {} // Uses custom shape verification. - explicit HloVerifier( - std::function()> shape_verifier_factory) + explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) : shape_verifier_factory_(std::move(shape_verifier_factory)) {} ~HloVerifier() override = default; @@ -129,7 +148,7 @@ class HloVerifier : public HloPassInterface { // expectations. This is a factory function because ShapeVerifier, Note that // ShapeVerifier, being a DfsHloVisitor, is stateful. We want a clean object // for each run of the verifier. - std::function()> shape_verifier_factory_; + ShapeVerifierFactory shape_verifier_factory_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 90e1f0acdc4cdeda280dabaab2df66b181d0f407..f494748e17fc2d0de74dec67f7414d4791f76a07 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -102,6 +102,8 @@ namespace xla { case HloOpcode::kExp: case HloOpcode::kFft: case HloOpcode::kFusion: + case HloOpcode::kGather: + case HloOpcode::kHostCompute: case HloOpcode::kLog: case HloOpcode::kMap: case HloOpcode::kParameter: diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 0cb9b5d8107cd8bf468b07d5fe2a22930d9e8b8c..883063d0f075f5b0d79edc01bcd27a7c579272f4 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -93,7 +93,7 @@ StatusOr> InterpreterExecutable::ExecuteOnStream( TF_ASSIGN_OR_RETURN(std::unique_ptr result, transfer_manager->AllocateShapedBuffer( result_literal->shape(), run_options->allocator(), - run_options->device_ordinal())); + executor->device_ordinal())); TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( executor, *result_literal, *result)); diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index fce135ef61a7868386b869def1a79167c428d928..4929300f7d30ef6fa6c9e128a781e7780f54a520 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -53,6 +53,83 @@ limitations under the License. namespace xla { +// For now moving only one API here, but we should have a single top level +// anonymous namespace, instead of three or four spread all over this file. +namespace { + +// Creates and returns a copy of the given instruction with a different +// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple +// instruction producing the copy is returned. +StatusOr CreateCopyWithNewLayout( + const Shape& shape_with_layout, HloInstruction* instruction) { + TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); + DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape())) + << ShapeUtil::HumanString(shape_with_layout) << " " + << ShapeUtil::HumanString(instruction->shape()) + << " instruction: " << instruction->ToString(); + + if (ShapeUtil::IsTuple(instruction->shape())) { + // Deep-copy tuples. + std::vector element_copies; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); + ++i) { + HloInstruction* gte = instruction->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, + i)); + + // Recurse to copy each elements. + TF_ASSIGN_OR_RETURN( + HloInstruction * element_copy, + CreateCopyWithNewLayout( + ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); + element_copies.push_back(element_copy); + } + // Gather element copies into a tuple with a new Tuple instruction. + HloInstruction* tuple_copy = instruction->parent()->AddInstruction( + HloInstruction::CreateTuple(element_copies)); + LayoutUtil::ClearLayout(tuple_copy->mutable_shape()); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + shape_with_layout, tuple_copy->mutable_shape())); + return tuple_copy; + } else if (ShapeUtil::IsArray(instruction->shape())) { + HloInstruction* copy = + instruction->parent()->AddInstruction(HloInstruction::CreateUnary( + instruction->shape(), HloOpcode::kCopy, instruction)); + LayoutUtil::ClearLayout(copy->mutable_shape()); + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + shape_with_layout, copy->mutable_shape())); + + return copy; + } else { + return FailedPrecondition( + "Can only copy array and tuple shaped instructions"); + } +} + +// Creates a copy of the given operand if the operand's layout does not match +// the given layout. This copy replaces the use in the given instruction. Tuple +// operands will be deep-copied. +Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, + HloInstruction* instruction, + int64 operand_no) { + HloInstruction* operand = instruction->mutable_operand(operand_no); + TF_RET_CHECK(operand_layout.LayoutIsSet()); + TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); + + if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { + // Operand layout already matches our constraint. Nothing to do. + return Status::OK(); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, + CreateCopyWithNewLayout(operand_layout.shape(), operand)); + + return instruction->ReplaceOperandWith(operand_no, operand_copy); +} + +} // namespace + std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint) { out << constraint.ToString(); @@ -115,17 +192,34 @@ LayoutConstraints::LayoutConstraints( } } +PointsToSet::BufferSet* LayoutConstraints::GetBufferSet( + const HloInstruction* instruction) const { + auto it = buffer_sets_cache_.find(instruction); + if (it != buffer_sets_cache_.end()) { + return it->second.get(); + } + auto& buffer_set = + buffer_sets_cache_ + .emplace(instruction, MakeUnique()) + .first->second; + const auto& points_to_set = points_to_analysis_.GetPointsToSet(instruction); + points_to_set.ForEachElement( + [&buffer_set](const ShapeIndex& /*index*/, + const PointsToSet::BufferList& buffers) { + buffer_set->insert(buffers.begin(), buffers.end()); + }); + return buffer_set.get(); +} + bool LayoutConstraints::OperandBufferForwarded( const HloInstruction* instruction, int64 operand_no) const { // The operand is potentially forwarded if the intersection of points-to sets // of the operand and the instruction is non-empty. - auto output_buffers = - points_to_analysis_.GetPointsToSet(instruction).CreateFlattenedSet(); - auto operand_buffers = - points_to_analysis_.GetPointsToSet(instruction->operand(operand_no)) - .CreateFlattenedSet(); - for (const LogicalBuffer* output_buffer : output_buffers) { - if (operand_buffers.count(output_buffer) > 0) { + PointsToSet::BufferSet* output_buffers = GetBufferSet(instruction); + PointsToSet::BufferSet* operand_buffers = + GetBufferSet(instruction->operand(operand_no)); + for (const LogicalBuffer* output_buffer : *output_buffers) { + if (operand_buffers->count(output_buffer) > 0) { return true; } } @@ -512,6 +606,36 @@ Status LayoutAssignment::AddMandatoryConstraints( body_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( body_layout.result_shape(), instruction, 0)); + } else if (instruction->opcode() == HloOpcode::kConditional) { + // The layout of the true and false computations must match, and must + // be the layout of the kConditional instruction. + TF_RET_CHECK(instruction->operand_count() == 3); + + HloComputation* true_computation = instruction->true_computation(); + HloComputation* false_computation = instruction->false_computation(); + const HloInstruction* true_operand = instruction->operand(1); + const HloInstruction* false_operand = instruction->operand(2); + + TF_RET_CHECK(true_computation->num_parameters() == 1); + TF_RET_CHECK(false_computation->num_parameters() == 1); + ComputationLayout& true_computation_layout = + FindOrDie(computation_layouts_, true_computation); + ComputationLayout& false_computation_layout = + FindOrDie(computation_layouts_, false_computation); + + DCHECK(ShapeUtil::Compatible(true_operand->shape(), + true_computation_layout.parameter_shape(0))); + DCHECK(ShapeUtil::Compatible( + false_operand->shape(), false_computation_layout.parameter_shape(0))); + + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + true_computation_layout.result_shape(), instruction)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + true_computation_layout.parameter_shape(0), instruction, 1, + /*mandatory=*/true)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + false_computation_layout.parameter_shape(0), instruction, 2, + /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { if (!CustomCallRequiresMajorFirstLayout(instruction)) { continue; @@ -598,6 +722,33 @@ Status CheckWhileLayout(HloInstruction* while_inst, return Status::OK(); } +Status CheckConditionalLayout( + HloInstruction* instruction, + const ComputationLayout& true_computation_layout, + const ComputationLayout& false_computation_layout) { + HloComputation* true_computation = instruction->true_computation(); + HloComputation* false_computation = instruction->false_computation(); + const HloInstruction* true_operand = instruction->operand(1); + const HloInstruction* false_operand = instruction->operand(2); + + TF_RET_CHECK(true_computation_layout.result_layout() == + false_computation_layout.result_layout()); + TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape( + instruction->shape())); + TF_RET_CHECK(true_computation_layout.result_layout().MatchesLayoutInShape( + true_computation->root_instruction()->shape())); + TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape( + instruction->shape())); + TF_RET_CHECK(false_computation_layout.result_layout().MatchesLayoutInShape( + false_computation->root_instruction()->shape())); + TF_RET_CHECK(true_computation_layout.parameter_layout(0).MatchesLayoutInShape( + true_operand->shape())); + TF_RET_CHECK( + false_computation_layout.parameter_layout(0).MatchesLayoutInShape( + false_operand->shape())); + return Status::OK(); +} + // Fusion parameters must match the layout of the fusion instructions operands, // and the root of the fusion expression must match the layout of the fusion // instruction. @@ -710,6 +861,13 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, instruction->while_condition()), FindOrDie(computation_layouts_, instruction->while_body()))); break; + case HloOpcode::kConditional: + TF_RETURN_IF_ERROR(CheckConditionalLayout( + instruction, + FindOrDie(computation_layouts_, instruction->true_computation()), + FindOrDie(computation_layouts_, + instruction->false_computation()))); + break; default: break; } @@ -1165,77 +1323,6 @@ StatusOr InferArrayLayout( return *first_buffer_layout; } -// Creates and returns a copy of the given instruction with a different -// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple -// instruction producing the copy is returned. -StatusOr CreateCopyWithNewLayout( - const Shape& shape_with_layout, HloInstruction* instruction) { - TF_RET_CHECK(LayoutUtil::HasLayout(shape_with_layout)); - DCHECK(ShapeUtil::Compatible(shape_with_layout, instruction->shape())) - << ShapeUtil::HumanString(shape_with_layout) << " " - << ShapeUtil::HumanString(instruction->shape()) - << " instruction: " << instruction->ToString(); - - if (ShapeUtil::IsTuple(instruction->shape())) { - // Deep-copy tuples. - std::vector element_copies; - for (int64 i = 0; i < ShapeUtil::TupleElementCount(instruction->shape()); - ++i) { - HloInstruction* gte = instruction->parent()->AddInstruction( - HloInstruction::CreateGetTupleElement( - ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, - i)); - - // Recurse to copy each elements. - TF_ASSIGN_OR_RETURN( - HloInstruction * element_copy, - CreateCopyWithNewLayout( - ShapeUtil::GetSubshape(shape_with_layout, {i}), gte)); - element_copies.push_back(element_copy); - } - // Gather element copies into a tuple with a new Tuple instruction. - HloInstruction* tuple_copy = instruction->parent()->AddInstruction( - HloInstruction::CreateTuple(element_copies)); - LayoutUtil::ClearLayout(tuple_copy->mutable_shape()); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - shape_with_layout, tuple_copy->mutable_shape())); - return tuple_copy; - } else if (ShapeUtil::IsArray(instruction->shape())) { - HloInstruction* copy = - instruction->parent()->AddInstruction(HloInstruction::CreateUnary( - instruction->shape(), HloOpcode::kCopy, instruction)); - LayoutUtil::ClearLayout(copy->mutable_shape()); - TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( - shape_with_layout, copy->mutable_shape())); - - return copy; - } else { - return FailedPrecondition( - "Can only copy array and tuple shaped instructions"); - } -} - -// Creates a copy of the given operand if the operand's layout does not match -// the given layout. This copy replaces the use in the given instruction. Tuple -// operands will be deep-copied. -Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, - HloInstruction* instruction, - int64 operand_no) { - HloInstruction* operand = instruction->mutable_operand(operand_no); - TF_RET_CHECK(operand_layout.LayoutIsSet()); - TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape())); - - if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) { - // Operand layout already matches our constraint. Nothing to do. - return Status::OK(); - } - - TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy, - CreateCopyWithNewLayout(operand_layout.shape(), operand)); - - return instruction->ReplaceOperandWith(operand_no, operand_copy); -} - // For fusion instructions, set the layout of each fused parameter instruction // to match the layout of its corresponding fusion instruction operand. Also, // set the layout of the fused root to match the layout of the fusion diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 29018584487cabfd740d7914625c2a50f552d6ff..7126cb50cf168241979178c9e1077051cc935e53 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/platform/types.h" namespace xla { @@ -199,6 +200,11 @@ class LayoutConstraints { string ToString() const; private: + // Find a bufferset in the bufferset cache. This is useful since we can + // currently create the flattened buffer set for the same instruction many + // times, which is often slow. + PointsToSet::BufferSet* GetBufferSet(const HloInstruction* instruction) const; + // The set of BufferLayoutConstraints applied to the computation. std::unordered_map buffer_constraints_; @@ -221,6 +227,10 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set unconstrained_buffer_ids_; + mutable tensorflow::gtl::FlatMap> + buffer_sets_cache_; + HloComputation* computation_; }; diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index e269a13459f1146f1d2952870399827d9e705e38..62feb7c1e9da0b3ecb9c21b876d86935775531d7 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -590,6 +590,85 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { transpose->shape(), {2, 3, 0, 1})); } +// TransposeIsBitcast shouldn't be called without layout information. +TEST_F(LayoutAssignmentTest, TransposeIsBitcastFail) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + Shape input_shape_with_layout(input_shape); + *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto hlo = builder.AddInstruction( + HloInstruction::CreateTranspose(input_shape, param, {0, 2, 1})); + // Clear the default layout assigned to the instruction. + LayoutUtil::ClearLayout(hlo->mutable_shape()); + EXPECT_DEATH(ShapeUtil::TransposeIsBitcast(hlo->operand(0)->shape(), + hlo->shape(), hlo->dimensions()), + "LayoutUtil::HasLayout"); +} + +// ReshapeIsBitcast shouldn't be called without layout information. +TEST_F(LayoutAssignmentTest, ReshapeIsBitcastFail) { + auto builder = HloComputation::Builder(TestName()); + Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2}); + Shape input_shape_with_layout(input_shape); + *input_shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout({2, 1, 0}); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape_with_layout, "param")); + auto hlo = + builder.AddInstruction(HloInstruction::CreateReshape(input_shape, param)); + // Clear the default layout assigned to the instruction. + LayoutUtil::ClearLayout(hlo->mutable_shape()); + EXPECT_DEATH( + ShapeUtil::ReshapeIsBitcast(hlo->operand(0)->shape(), hlo->shape()), + "LayoutUtil::HasLayout"); +} + +// Check that the computation below doesn't crash the compiler. +// +// Within a fusion computation, only the parameters and result get assigned a +// layout. When we run the algebraic simplifier on this computation post layout +// assignment, it should not call TransposeIsBitcast on the `transpose` node +// inside the fusion computation as TransposeIsBitcast checks both input_shape +// and output_shape have layouts. +TEST_F(LayoutAssignmentTest, TransposeWithinFusionDoesNotCrash) { + const char* module_str = R"( + HloModule test_module + + fused_computation { + param_1 = f32[2,2,2]{2,1,0} parameter(1) + transpose = f32[2,2,2]{2,1,0} transpose(param_1), dimensions={0,2,1} + reduce_1 = f32[] parameter(0) + broadcast_1 = f32[2,2,2]{2,1,0} broadcast(reduce_1), dimensions={} + ROOT divide_1 = f32[2,2,2]{2,1,0} divide(transpose, broadcast_1) + } + + ENTRY entry_computation { + fusion.1 = f32[2,2,2]{2,1,0} parameter(1) + reduce.1 = f32[] parameter(0) + fusion.2 = f32[2,2,2]{2,1,0} fusion(reduce.1, fusion.1), kind=kLoop, calls=fused_computation + ROOT tuple.1 = (f32[2,2,2]{2,1,0}) tuple(fusion.2) + } + )"; + + auto module = tools::Parse(module_str).ValueOrDie(); + + module = + backend() + .compiler() + ->RunHloPasses(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .ConsumeValueOrDie(); + + EXPECT_EQ( + ::tensorflow::Status::OK(), + backend() + .compiler() + ->RunBackend(std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr) + .status()); +} + // A GTE inside of a fusion node inherits the layout of its operand (which // should, if we keep following operands, eventually be a parameter). TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { @@ -629,33 +708,92 @@ TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { LayoutUtil::MakeLayout({2, 1, 0})); AssignLayouts(module.get(), &computation_layout); - HloComputation* fused_computation = *std::find_if( - module->computations().begin(), module->computations().end(), - [](const HloComputation* c) { return c->name() == "fused_computation"; }); - - auto fused_instr = [&](const string& name) { - auto it = std::find_if( - fused_computation->instructions().begin(), - fused_computation->instructions().end(), - [&](const HloInstruction* i) { return i->name() == name; }); - CHECK(it != fused_computation->instructions().end()); - return *it; + auto layout_of = [&](tensorflow::StringPiece name) { + return FindInstruction(module.get(), name) + ->shape() + .layout() + .minor_to_major(); }; - EXPECT_THAT(fused_instr("gte0")->shape().layout().minor_to_major(), - ElementsAre(0, 1, 2)); - EXPECT_THAT( - fused_instr("gte1")->shape().tuple_shapes(0).layout().minor_to_major(), - ElementsAre(1, 2, 0)); - EXPECT_THAT( - fused_instr("gte1")->shape().tuple_shapes(1).layout().minor_to_major(), - ElementsAre(2, 0, 1)); - EXPECT_THAT(fused_instr("gte1a")->shape().layout().minor_to_major(), + EXPECT_THAT(layout_of("gte0"), ElementsAre(0, 1, 2)); + EXPECT_THAT(layout_of("gte1a"), ElementsAre(1, 2, 0)); + EXPECT_THAT(layout_of("gte1b"), ElementsAre(2, 0, 1)); + EXPECT_THAT(layout_of("fresult"), ElementsAre(2, 1, 0)); + EXPECT_THAT(FindInstruction(module.get(), "gte1") + ->shape() + .tuple_shapes(0) + .layout() + .minor_to_major(), ElementsAre(1, 2, 0)); - EXPECT_THAT(fused_instr("gte1b")->shape().layout().minor_to_major(), + EXPECT_THAT(FindInstruction(module.get(), "gte1") + ->shape() + .tuple_shapes(1) + .layout() + .minor_to_major(), ElementsAre(2, 0, 1)); - EXPECT_THAT(fused_instr("fresult")->shape().layout().minor_to_major(), - ElementsAre(2, 1, 0)); +} + +TEST_F(LayoutAssignmentTest, ConditionalAsymmetricLayout) { + auto builder = HloComputation::Builder(TestName()); + auto module = CreateNewModule(); + Shape shape = ShapeUtil::MakeShape(F32, {128, 8}); + Shape tshape = ShapeUtil::MakeTupleShape({shape, shape}); + Shape result_tshape = ShapeUtil::MakeTupleShape({shape}); + + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, shape, "param1")); + auto pred = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(PRED, {}), "param2")); + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({param0, param1})); + + auto true_builder = HloComputation::Builder(TestName() + "_TrueBranch"); + { + auto param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, tshape, "param")); + auto gte0 = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, param, 0)); + auto gte1 = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, param, 1)); + auto add = true_builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, gte0, gte1)); + true_builder.AddInstruction(HloInstruction::CreateTuple({add})); + } + HloComputation* true_computation = + module->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_FalseBranch"); + { + Shape xshape = ShapeUtil::MakeShapeWithLayout(F32, {128, 8}, {0, 1}); + false_builder.AddInstruction( + HloInstruction::CreateParameter(0, tshape, "param")); + // Using infeed as layout assignment does not mess up with it. + auto infeed = + false_builder.AddInstruction(HloInstruction::CreateInfeed(xshape, "")); + false_builder.AddInstruction(HloInstruction::CreateTuple({infeed})); + } + HloComputation* false_computation = + module->AddEmbeddedComputation(false_builder.Build()); + builder.AddInstruction(HloInstruction::CreateConditional( + result_tshape, pred, tuple, true_computation, tuple, false_computation)); + + HloComputation* computation = module->AddEntryComputation(builder.Build()); + ComputationLayout computation_layout(computation->ComputeProgramShape()); + + AssignLayouts(module.get(), &computation_layout); + + const HloInstruction* true_root = true_computation->root_instruction(); + const HloInstruction* false_root = false_computation->root_instruction(); + EXPECT_THAT(true_root->opcode(), HloOpcode::kTuple); + EXPECT_THAT(false_root->opcode(), HloOpcode::kTuple); + + const HloInstruction* true_result = true_root->operand(0); + const HloInstruction* false_result = false_root->operand(0); + EXPECT_TRUE(LayoutUtil::Equal(true_result->shape().layout(), + false_result->shape().layout())); + EXPECT_THAT(false_result->opcode(), HloOpcode::kCopy); } } // namespace diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index 2c2a02f6375343d67dfb155bbb03729ff6e490d2..f8b309488eeb5391b1cad5db760934ec1f7e3521 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -35,8 +35,7 @@ class PointsToAnalysisTestBase : public HloTestBase { CHECK_NOTNULL(module_.get()); points_to_analysis_ = TuplePointsToAnalysis::Run(module_.get()).ConsumeValueOrDie(); - dataflow_analysis_ = - HloDataflowAnalysis::Run(module_.get()).ConsumeValueOrDie(); + dataflow_analysis_ = HloDataflowAnalysis::Run(*module_).ConsumeValueOrDie(); } void BuildModuleAndRunAnalysis(std::unique_ptr computation) { diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index f98fc0400a7d827a29dcddc5eecf9a4a01e76590..911b243fe28a5baf8a4b8ed752b892265f5388ac 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -14,12 +14,29 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_compiler.h" +#include "tensorflow/core/platform/denormal.h" + +#ifdef __FAST_MATH__ +#error "Don't build XLA with -ffast-math" +#endif namespace xla { StatusOr>> LLVMCompiler::Compile( std::vector> modules, std::vector> stream_execs, DeviceMemoryAllocator* device_allocator) { + // Tensorflow tries to enable the following behaviors in all its threads: + // + // - Denormals are zero (DAZ): roughly, operations treat denormal floats as + // zero. + // - Flush denormals to zero (FTZ): roughly, operations produce zero instead + // of denormal floats. + // + // In theory enabling these shouldn't matter since the compiler should ideally + // not leak its environment into generated code, but we turn off DAZ and FTZ + // to get some defense-in-depth. + tensorflow::port::ScopedDontFlushDenormal dont_flush_denormals; + std::vector> result; for (size_t i = 0; i < modules.size(); i++) { if (stream_execs[i].size() != 1) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index ffc78bd5cfac3df1001d8125327607c85169ae92..37261ed1e665ebed9685751161a412ad114a9e96 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -54,6 +54,7 @@ cc_library( "@llvm//:core", "@llvm//:support", "@llvm//:target", + "@llvm//:transform_utils", ], ) diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 6384c7f46f5ebbedaeda232b40095611a5d738a4..f3642cf0a1c202e785d8e2d3fe469f95eff212c8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -160,7 +160,8 @@ IrArray::Index IrArray::Index::SourceIndexOfReshape( } } - if (linear() != nullptr && + if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape) && ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) { return Index(source_multidim_index, linear(), input_shape); } @@ -195,10 +196,13 @@ IrArray::Index IrArray::Index::SourceIndexOfTranspose( llvm::IRBuilder<>* builder) const { std::vector operand_multidim_index = Permute(dimension_mapping, multidim()); - if (linear() != nullptr && + + if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) && + LayoutUtil::HasLayout(shape) && ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) { return Index(operand_multidim_index, linear(), operand_shape); } + return Index(operand_multidim_index); } diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 8d1e6338e189a055ac20f09961a783b52600866d..5c1866311d1ae1e0c33ab061ee326d86d647a908 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -20,9 +20,11 @@ limitations under the License. #include #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/Target/TargetOptions.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -61,6 +63,16 @@ llvm::StringRef AsStringRef(tensorflow::StringPiece str) { return llvm::StringRef(str.data(), str.size()); } +std::unique_ptr DropConstantInitializers( + const llvm::Module& module) { + std::unique_ptr cloned_module = CloneModule(module); + for (llvm::GlobalVariable& global_var : cloned_module->globals()) { + global_var.setInitializer(nullptr); + global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage); + } + return cloned_module; +} + string DumpModuleToString(const llvm::Module& module) { std::string buffer_string; llvm::raw_string_ostream ostream(buffer_string); @@ -672,6 +684,19 @@ static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { return uniquer->GetUniqueName(prefix); } +static Status CreateAndWriteStringToFile(const string& directory_name, + const string& file_name, + const string& text) { + std::unique_ptr f; + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->RecursivelyCreateDir(directory_name)); + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->NewWritableFile(file_name, &f)); + TF_RETURN_IF_ERROR(f->Append(text)); + TF_RETURN_IF_ERROR(f->Close()); + return Status::OK(); +} + Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized) { @@ -686,13 +711,17 @@ Status DumpIRToDirectory(const string& directory_name, directory_name, tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); - std::unique_ptr f; - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->RecursivelyCreateDir(directory_name)); - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->NewWritableFile(ir_file_name, &f)); - TF_RETURN_IF_ERROR(f->Append(DumpModuleToString(llvm_module))); - return f->Close(); + // For some models the embedded constants can be huge, so also dump the module + // with the constants stripped to get IR that is easier to manipulate. + string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( + directory_name, + tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll")); + + TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( + directory_name, ir_file_name, DumpModuleToString(llvm_module))); + return CreateAndWriteStringToFile( + directory_name, ir_no_constant_initializers_file_name, + DumpModuleToString(*DropConstantInitializers(llvm_module))); } llvm::Function* CreateFunction(llvm::FunctionType* function_type, diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 98dfc89867ab33788c4cc837a66d6751a1ef2507..43d0f605985819afdaf2db2309a0bfb86f230fe3 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -1445,6 +1446,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { case OpRequest::kFftRequest: handle_status = computation->AddFftInstruction(arg->fft_request()); break; + case OpRequest::kGatherRequest: + handle_status = computation->AddGatherInstruction(arg->gather_request()); + break; case OpRequest::kGetTupleElementRequest: handle_status = computation->AddGetTupleElementInstruction( arg->get_tuple_element_request()); @@ -1456,6 +1460,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddOutfeedInstruction(arg->outfeed_request()); break; + case OpRequest::kHostComputeRequest: + handle_status = + computation->AddHostComputeInstruction(arg->host_compute_request()); + break; case OpRequest::kMapRequest: { TF_ASSIGN_OR_RETURN( UserComputation * to_apply, @@ -1548,8 +1556,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { case OpRequest::kSendRequest: { TF_RETURN_IF_ERROR( channel_tracker_.RegisterSend(arg->send_request().channel_handle())); - TF_RETURN_IF_ERROR(computation->AddSendInstruction(arg->send_request())); - return tensorflow::Status::OK(); + // Send does not return a value, but we need a handle to be able to + // set OpMetadata and OpSharding (device assignment). + handle_status = computation->AddSendInstruction(arg->send_request()); + break; } case OpRequest::kRecvRequest: { TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 4ba6da6ccc44be8f3c70d2af80b30f0b2e388c2a..607a672025f939a376a32db22acba5bc9168e420 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -209,7 +209,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, } // Check that init_value's shape is suitable for reducer_shape. - if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, + init_value_shape)) { return InvalidArgument( "Reduction function's accumulator shape differs from the " "init_value shape: %s vs %s", @@ -220,8 +221,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, // Check that the inputs can be passed in as the second argument. const Shape& input_element_shape = ShapeUtil::MakeShape(input_element_type, {}); - if (!ShapeUtil::Compatible(input_element_shape, - reducer_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape, + reducer_shape.parameters(1))) { return InvalidArgument( "Reduction function's second parameter shape differs from the " "input type element type: %s vs %s", @@ -231,7 +232,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, // Currently the accumulator and inputs must be the same type, // though that restriction could be relaxed. - if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, + reducer_shape.parameters(1))) { return InvalidArgument( "Reduction function's second parameter shape currently must " "match the result shape. Got %s vs %s", @@ -394,11 +396,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, dimension); } const Shape* arg_shape = nullptr; + PrimitiveType element_type = PRIMITIVE_TYPE_INVALID; for (const Shape* shape : arg_shapes) { TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(*shape, "operand of concatenation")); if (!arg_shape) { arg_shape = shape; + element_type = arg_shape->element_type(); continue; } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { @@ -409,7 +413,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape).c_str()); } - if (arg_shape->element_type() != shape->element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( "cannot concatenate arrays with different element types: %s vs %s", PrimitiveType_Name(arg_shape->element_type()).c_str(), @@ -431,6 +435,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(*shape).c_str(), dimension); } } + element_type = ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape); } std::vector new_dimensions(arg_shape->dimensions().begin(), @@ -438,7 +443,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, for (size_t i = 1; i < arg_shapes.size(); ++i) { new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension); } - return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions); + return ShapeUtil::MakeShape(element_type, new_dimensions); } /* static */ StatusOr ShapeInference::InferConvertShape( @@ -536,7 +541,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), padding_config.ShortDebugString().c_str()); } - if (operand_shape.element_type() != padding_value_shape.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, + padding_value_shape)) { return InvalidArgument( "the element types of the operands to pad do not match"); } @@ -548,7 +554,9 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, std::max(operand_shape.dimensions(i) - 1, 0LL) * padding_config.dimensions(i).interior_padding(); } - return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions); + return ShapeUtil::MakeShape( + ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), + dimensions); } // Current DotDimensionNumbers Requirements: @@ -673,7 +681,7 @@ Status ValidateDotDimensionNumbers( }; // Check if both element types are the same. - if (lhs.element_type() != rhs.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return fail("element types do not match"); } @@ -736,7 +744,8 @@ Status ValidateDotDimensionNumbers( dimensions.push_back(rhs.dimensions(i)); } } - Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions); + Shape result = ShapeUtil::MakeShape( + ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); @@ -767,7 +776,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::HumanString(rhs).c_str()); } } - return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions); + return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), + output_dimensions); } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( @@ -829,6 +839,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // specified in broadcast_dimensions are then changed to match the // corresponding dimension size in smaller_shape. Shape output_shape(larger_shape); + output_shape.set_element_type( + ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape)); for (int i = 0; i < smaller_shape.dimensions_size(); ++i) { int64 dimension_to_match = broadcast_dimensions.at(i); @@ -878,7 +890,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); - if (!ShapeUtil::SameElementType(lhs, rhs)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "binary op %s with different element types: %s and %s", BinaryOperation_Name(operation).c_str(), @@ -897,10 +909,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } } - if (ShapeUtil::Compatible(lhs, rhs)) { + if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) { // If the shapes are the same other than layout, the output shape is the // same (elementwise op). - return lhs; + return ShapeUtil::ChangeElementType( + lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs)); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { @@ -973,7 +986,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_ASSIGN_OR_RETURN(const Shape& shape, InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions)); - if (lhs.element_type() == F32) { + if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); } else { return Unimplemented("complex component type not supported"); @@ -1078,12 +1091,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map")); - if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) { + if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { continue; } if (!ShapeUtil::IsTuple(*arg_shapes[i]) && !ShapeUtil::IsTuple(*arg_shape) && - ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) { + ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], + *arg_shape)) { if (ShapeUtil::IsScalar(*arg_shapes[i])) { continue; } @@ -1148,7 +1162,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( i, ShapeUtil::HumanString(parameter_shape).c_str()); } - if (parameter_shape.element_type() != arg_shape->element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, + *arg_shape)) { return InvalidArgument( "mapped computation's parameter type has to match argument element " "type; got parameter %d shape: %s, argument shape: %s", @@ -1221,7 +1236,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " @@ -1230,7 +1246,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " @@ -1329,7 +1346,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1339,7 +1357,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1349,7 +1368,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1359,7 +1379,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1481,7 +1502,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(output_grad_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(output_grad_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " @@ -1490,7 +1512,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " @@ -1499,7 +1522,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " @@ -1508,7 +1532,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(var_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " @@ -1569,7 +1594,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution")); - if (!ShapeUtil::SameElementType(lhs, rhs)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s", ShapeUtil::HumanString(lhs).c_str(), @@ -1714,8 +1739,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( dimensions[dnums.output_spatial_dimensions(i)] = window_output_shape.dimensions(i); } - - return ShapeUtil::MakeShape(lhs.element_type(), dimensions); + return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), + dimensions); } /* static */ StatusOr ShapeInference::InferFftShape( @@ -1877,16 +1902,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } const Shape& operand_element_shape = ShapeUtil::MakeShape(operand_shape.element_type(), {}); - if (!ShapeUtil::Compatible(operand_element_shape, - select_shape.parameters(0))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, + select_shape.parameters(0))) { return InvalidArgument( "select function's first parameter shape currently must " "match the operand element shape. Got %s vs %s", ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), ShapeUtil::HumanString(operand_element_shape).c_str()); } - if (!ShapeUtil::Compatible(operand_element_shape, - select_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, + select_shape.parameters(1))) { return InvalidArgument( "select function's second parameter shape currently must " "match the operand element shape. Got %s vs %s", @@ -1903,7 +1928,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( InferWindowOutputShape(operand_shape, window, operand_shape.element_type(), /*allow_negative_padding=*/false)); - if (!ShapeUtil::Compatible(source_shape, window_result_shape)) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape, + window_result_shape)) { return InvalidArgument( "source shape does not match the shape of window-reduced operand: " "source(%s), window-reduced operand(%s)", @@ -2086,7 +2112,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } - if (operand_shape.element_type() != update_shape.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, + update_shape)) { return InvalidArgument( "dynamic update slice update element type does not match argument. " "operand.element_type: %s vs update.element_type: %s", @@ -2322,24 +2349,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); - if (!ShapeUtil::SameElementType(min, operand) || - !ShapeUtil::SameElementType(max, operand)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || + !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("clamp op with different operand types: %s, %s, %s", ShapeUtil::HumanString(min).c_str(), ShapeUtil::HumanString(operand).c_str(), ShapeUtil::HumanString(max).c_str()); } - if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) && - (ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) { + if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || + ShapeUtil::IsScalar(min)) && + (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) || + ShapeUtil::IsScalar(max)))) { return operand; } if (ShapeUtil::IsScalar(operand)) { - if (ShapeUtil::Compatible(min, max)) { - return min; + if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) { + return ShapeUtil::ChangeElementType(min, operand.element_type()); } else if (ShapeUtil::IsScalar(min)) { - return max; + return ShapeUtil::ChangeElementType(max, operand.element_type()); } else if (ShapeUtil::IsScalar(max)) { - return min; + return ShapeUtil::ChangeElementType(min, operand.element_type()); } } return Unimplemented( @@ -2352,7 +2381,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // broadcast from all operands, not just the predicate. /* static */ StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { - if (!ShapeUtil::Compatible(on_true, on_false)) { + bool compatible; + if (ShapeUtil::IsTuple(on_true)) { + // Select only defines the top-level buffer, so if it's a tuple, the two + // input must match exactly. + compatible = ShapeUtil::Compatible(on_true, on_false); + } else { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false); + } + if (!compatible) { return InvalidArgument( "operands to select must be the same shape; got %s and %s", ShapeUtil::HumanString(on_true).c_str(), @@ -2367,7 +2404,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. - return on_true; + return ShapeUtil::ChangeElementType( + on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); } else { return Unimplemented( "select operation with non-scalar predicate with dimensionality " @@ -2410,4 +2448,209 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return to_apply.result(); } +static Status ValidateGatherDimensionNumbers( + const Shape& input_shape, + tensorflow::gtl::ArraySlice gather_indices_shape, + const GatherDimensionNumbers& dim_numbers) { + if (!c_is_sorted(dim_numbers.output_window_dims())) { + return InvalidArgument( + "Output window dimensions in gather op must be ascending; got: %s", + Join(dim_numbers.output_window_dims(), ", ").c_str()); + } + + if (c_adjacent_find(dim_numbers.output_window_dims()) != + dim_numbers.output_window_dims().end()) { + return InvalidArgument( + "Output window dimensions in gather op must not repeat; got: %s", + Join(dim_numbers.output_window_dims(), ", ").c_str()); + } + + const int64 output_window_dim_count = dim_numbers.output_window_dims_size(); + const int64 output_shape_rank = + output_window_dim_count + gather_indices_shape.size() - 1; + + for (int i = 0; i < dim_numbers.output_window_dims_size(); ++i) { + int64 window_index = dim_numbers.output_window_dims(i); + if (window_index < 0 || window_index >= output_shape_rank) { + return InvalidArgument( + "Window index %d in gather op is out of bounds; got %lld, but should " + "have been in [0,%lld)", + i, window_index, output_shape_rank); + } + } + + if (dim_numbers.gather_dims_to_operand_dims_size() != + gather_indices_shape[dim_numbers.index_vector_dim()]) { + return InvalidArgument( + "Gather op has %d elements in gather_dims_to_operand_dims and the " + "bound of dimension index_vector_dim=%lld of gather_indices is " + "%lld. These two numbers must be equal.", + dim_numbers.gather_dims_to_operand_dims_size(), + dim_numbers.index_vector_dim(), + gather_indices_shape[dim_numbers.index_vector_dim()]); + } + + for (int i = 0; i < dim_numbers.gather_dims_to_operand_dims_size(); i++) { + int64 gather_dim_to_input_dim = dim_numbers.gather_dims_to_operand_dims(i); + if (gather_dim_to_input_dim < 0 || + gather_dim_to_input_dim >= input_shape.dimensions_size()) { + return InvalidArgument( + "Invalid gather_dims_to_operand_dims mapping; domain is [0, %d), " + "got: %d->%lld", + input_shape.dimensions_size(), i, gather_dim_to_input_dim); + } + } + + std::vector sorted_gather_dims_to_operand_dims( + dim_numbers.gather_dims_to_operand_dims().begin(), + dim_numbers.gather_dims_to_operand_dims().end()); + + c_sort(sorted_gather_dims_to_operand_dims); + + if (c_adjacent_find(sorted_gather_dims_to_operand_dims) != + sorted_gather_dims_to_operand_dims.end()) { + return InvalidArgument( + "Repeated dimensions are not allowed in gather_dims_to_operand_dims; " + "got: %s", + Join(dim_numbers.gather_dims_to_operand_dims(), ", ").c_str()); + } + + for (int64 elided_dim : dim_numbers.elided_window_dims()) { + if (elided_dim < 0 || elided_dim >= input_shape.dimensions_size()) { + return InvalidArgument( + "Invalid elided_window_dims set in gather op; valid range is [0, " + "%d), got: %lld", + input_shape.dimensions_size(), elided_dim); + } + } + + if (!c_is_sorted(dim_numbers.elided_window_dims())) { + return InvalidArgument( + "elided_window_dims in gather op must be sorted; got: %s", + Join(dim_numbers.elided_window_dims(), ", ").c_str()); + } + + if (c_adjacent_find(dim_numbers.elided_window_dims()) != + dim_numbers.elided_window_dims().end()) { + return InvalidArgument( + "Repeated dimensions not allowed in elided_window_dims in gather op; " + "got: %s", + Join(dim_numbers.elided_window_dims(), ", ").c_str()); + } + + return Status::OK(); +} + +/*static*/ StatusOr ShapeInference::InferGatherShape( + const Shape& input_shape, const Shape& gather_indices_shape, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(input_shape, "input tensor operand gather op")); + TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque( + gather_indices_shape, "gather indices operand of gather op")); + + if (!ShapeUtil::ElementIsIntegral(gather_indices_shape)) { + return InvalidArgument( + "Gather indices parameter must be an integral tensor; got %s", + ShapeUtil::HumanString(gather_indices_shape).c_str()); + } + + // We implicitly reshape gather indices of shape P[A,B,C] to P[A,B,C,1] if + // index_vector_dim is rank(P). The bounds of this expanded shape is + // stored in expanded_gather_indices_shape. + + if (gather_indices_shape.dimensions_size() < + gather_dim_numbers.index_vector_dim() || + gather_dim_numbers.index_vector_dim() < 0) { + return InvalidArgument( + "Gather index leaf dimension must be within [0, rank(gather_indices) + " + "1). rank(gather_indices) is %d and gather index leaf dimension is " + "%lld.", + gather_indices_shape.dimensions_size(), + gather_dim_numbers.index_vector_dim()); + } + + std::vector expanded_gather_indices_shape; + expanded_gather_indices_shape.reserve(gather_indices_shape.dimensions_size()); + c_copy(gather_indices_shape.dimensions(), + std::back_inserter(expanded_gather_indices_shape)); + if (expanded_gather_indices_shape.size() == + gather_dim_numbers.index_vector_dim()) { + expanded_gather_indices_shape.push_back(1); + } + + TF_RETURN_IF_ERROR(ValidateGatherDimensionNumbers( + input_shape, expanded_gather_indices_shape, gather_dim_numbers)); + + if (window_bounds.size() != input_shape.dimensions_size()) { + return InvalidArgument( + "Gather op must have one window bound for every input dimension; got: " + "len(window_bounds)=%lu, input_shape.rank=%d", + window_bounds.size(), input_shape.dimensions_size()); + } + + if (window_bounds.size() != + gather_dim_numbers.output_window_dims_size() + + gather_dim_numbers.elided_window_dims_size()) { + return InvalidArgument( + "All components of the window index in a gather op must either be a " + "output window index or explicitly elided; got len(window_bounds)=%lu, " + "output_window_bounds=%s, elided_window_bounds=%s", + window_bounds.size(), + Join(gather_dim_numbers.output_window_dims(), ",").c_str(), + Join(gather_dim_numbers.elided_window_dims(), ",").c_str()); + } + + for (int i = 0; i < window_bounds.size(); i++) { + int64 window_bound = window_bounds[i]; + int64 corresponding_input_bound = input_shape.dimensions(i); + if (window_bound < 0 || window_bound > corresponding_input_bound) { + return InvalidArgument( + "Window bound at index %d in gather op is out of range, must be " + "within " + "[0, %lld), got %lld", + i, corresponding_input_bound + 1, window_bound); + } + } + + for (int i = 0; i < gather_dim_numbers.elided_window_dims_size(); i++) { + if (window_bounds[gather_dim_numbers.elided_window_dims(i)] != 1) { + return InvalidArgument( + "Gather op can only elide window indices with bound 1, but bound is " + "%lld for index %lld at position %d", + window_bounds[gather_dim_numbers.elided_window_dims(i)], + gather_dim_numbers.elided_window_dims(i), i); + } + } + + int64 result_rank = gather_dim_numbers.output_window_dims_size() + + (expanded_gather_indices_shape.size() - 1); + int64 window_dims_seen = 0; + int64 gather_dims_seen = 0; + std::vector output_dim_bounds; + output_dim_bounds.reserve(result_rank); + for (int64 i = 0; i < result_rank; i++) { + int64 current_bound; + bool is_window_index = + c_binary_search(gather_dim_numbers.output_window_dims(), i); + if (is_window_index) { + while (c_binary_search(gather_dim_numbers.elided_window_dims(), + window_dims_seen)) { + window_dims_seen++; + } + current_bound = window_bounds[window_dims_seen++]; + } else { + if (gather_dims_seen == gather_dim_numbers.index_vector_dim()) { + gather_dims_seen++; + } + current_bound = expanded_gather_indices_shape[gather_dims_seen++]; + } + + output_dim_bounds.push_back(current_bound); + } + + return ShapeUtil::MakeShape(input_shape.element_type(), output_dim_bounds); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index b39151ebbc19f5d0b702a80da5069f58c8dfb07d..0d3045213db2230da3e18ffcb1a9923250560b64 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -37,6 +37,11 @@ namespace xla { // the expected result type for computations that are built up via the API -- // the shape that results from an operation is inferred. Some methods have // overloads for inferring shape at the HLO level. +// +// TODO(b/73352135): Shape inference does not issue very good error messages, in +// part because HloInstruction::ToString() is not available since shape +// inference runs before the HloInstruction object is created. We need a +// solution for this. class ShapeInference { public: // Infers the shape produced by applying the given unary operation to the @@ -248,6 +253,14 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const DotDimensionNumbers& dimension_numbers); + // Helper that infers the shape of the tensor produced by a gather operation + // with the given input shape, gather indices shape and gather dimension + // numbers. + static StatusOr InferGatherShape( + const Shape& input_shape, const Shape& gather_indices_shape, + const GatherDimensionNumbers& gather_dim_numbers, + tensorflow::gtl::ArraySlice window_bounds); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 026c021165785bd3945d6a846dae446ad45da9b7..029d2b3b86c00796db5e67b46490c8b178d571ec 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -18,15 +18,16 @@ limitations under the License. #include #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace { +using ::tensorflow::gtl::ArraySlice; using ::testing::ContainsRegex; using ::testing::HasSubstr; @@ -1527,5 +1528,458 @@ TEST_F(ShapeInferenceTest, BadSlice) { << statusor.status(); } +class GatherShapeInferenceTest : public ShapeInferenceTest { + protected: + const Shape s64_scalar_ = ShapeUtil::MakeShape(S64, {}); + const Shape s64_vector_5_ = ShapeUtil::MakeShape(S64, {5}); + const Shape s64_vector_32_ = ShapeUtil::MakeShape(S64, {32}); + const Shape s64_4d_tensor_10_9_8_7_1_ = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 1}); + const Shape s64_4d_tensor_10_9_8_7_5_ = + ShapeUtil::MakeShape(S64, {10, 9, 8, 7, 5}); + const Shape s64_4d_tensor_5_10_9_7_6_ = + ShapeUtil::MakeShape(S64, {5, 10, 9, 7, 6}); + const Shape s64_4d_tensor_10_9_5_7_6_ = + ShapeUtil::MakeShape(S64, {10, 9, 5, 7, 6}); + const Shape f32_5d_tensor_50_49_48_47_46_ = + ShapeUtil::MakeShape(F32, {50, 49, 48, 47, 46}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {s64_4d_tensor_10_9_8_7_1_, s64_4d_tensor_10_9_8_7_1_}); +}; + +TEST_F(GatherShapeInferenceTest, TensorFlowGather) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), + /*window_bounds=*/{64, 1})); + EXPECT_TRUE( + ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {64, 32}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowGatherV2) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{1}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/1), + /*window_bounds=*/{1, 48})); + EXPECT_TRUE( + ShapeUtil::Equal(gather_shape, ShapeUtil::MakeShape(F32, {32, 48}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowGatherNd) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape(matrix_64_48_, s64_4d_tensor_10_9_8_7_1_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/4), + /*window_bounds=*/{1, 48})); + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 48}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TensorFlowBatchDynamicSlice) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26})); + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 8, 7, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_A) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/2), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NonDefaultGatherIndicesLeafDim_B) { + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_5_10_9_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal( + gather_shape, + ShapeUtil::MakeShape(F32, {10, 9, 7, 6, 30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, NoOutputGatherDims) { + // This is equivalent to a dynamic slice. + TF_ASSERT_OK_AND_ASSIGN( + Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_vector_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0, 1, 2, 3, 4}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/0), + /*window_bounds=*/{30, 29, 28, 27, 26})); + + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27, 26}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, ScalarGatherIndices) { + // The gather indices "tensor" is a scalar S here that's used to slice out + // [S,0,0,0,0]..[S,30,29,28,27] into a [30,29,28,27] shaped result. + TF_ASSERT_OK_AND_ASSIGN(Shape gather_shape, + ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_scalar_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{0, 1, 2, 3}, + /*elided_window_dims=*/{0}, + /*gather_dims_to_operand_dims=*/{0}, + /*index_vector_dim=*/0), + /*window_bounds=*/{1, 30, 29, 28, 27})); + + EXPECT_TRUE(ShapeUtil::Equal(gather_shape, + ShapeUtil::MakeShape(F32, {30, 29, 28, 27}))) + << ShapeUtil::HumanString(gather_shape); +} + +TEST_F(GatherShapeInferenceTest, TupleShapedTensorInput) { + StatusOr statusor = ShapeInference::InferGatherShape( + tuple_shape_, s64_vector_32_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/1), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected non-tuple argument for input")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, TupleShapedGatherIndicesInput) { + StatusOr statusor = ShapeInference::InferGatherShape( + s64_vector_32_, tuple_shape_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Expected non-tuple argument for gather indices")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, FloatingPointGatherIndicesInput) { + StatusOr statusor = ShapeInference::InferGatherShape( + s64_vector_32_, vector_32_, + HloInstruction::MakeGatherDimNumbers(/*output_window_dims=*/{0}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{1}, + /*index_vector_dim=*/0), + /*window_bounds=*/{64, 1}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather indices parameter must be an integral tensor")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_NonAscendingWindowIndices) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 8, 7}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Output window dimensions in gather op must be ascending")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedWindowIndices) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 7}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Output window dimensions in gather op must not repeat")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowIndexOutOfBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 99, 100, 101}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window index 2 in gather op is out of bounds")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowIndexBarelyOutOfBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 9}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window index 4 in gather op is out of bounds")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingElidedWindowDims) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{4}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("All components of the window index in a gather op must either " + "be a output window index or explicitly elided")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_OutOfBoundsWindowToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{0, 1, 2, 3, 19}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Invalid elided_window_dims set in gather op; valid " + "range is [0, 5), got: 19")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedWindowToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{0, 1, 2, 3, 3}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions not allowed in elided_window_dims in gather op")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingGatherToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Gather op has 4 elements in gather_dims_to_operand_dims and " + "the bound of dimension index_vector_dim=4 of " + "gather_indices is 5. These two numbers must be equal.")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_OutOfBoundsGatherToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 7}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr("Invalid gather_dims_to_operand_dims mapping; domain is " + "[0, 5), got: 4->7")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_RepeatedGatherToInputMapping) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 3}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Repeated dimensions are not allowed in gather_dims_to_operand_dims")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_NonAscendingElidedWindowDims) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{2, 1}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{1, 1, 28, 27, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("elided_window_dims in gather op must be sorted")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, InvalidGatherDimNumbers_WindowBoundsTooLarge) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7}, + /*elided_window_dims=*/{2}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 1, 300, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Window bound at index 3 in gather op is out of range, " + "must be within [0, 48), got 300")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_MismatchingNumberOfWindowBounds) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 26}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT( + statusor.status().error_message(), + HasSubstr( + "Gather op must have one window bound for every input dimension")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, + InvalidGatherDimNumbers_WindowBoundsNot1ForElidedDim) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_8_7_5_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7}, + /*elided_window_dims=*/{1}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/4), + /*window_bounds=*/{30, 29, 28, 26, 20}); + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather op can only elide window indices with bound 1, " + "but bound is 29 for index 1 at position 0")) + << statusor.status(); +} + +TEST_F(GatherShapeInferenceTest, OutOfBoundsGatherIndicesLeafDim) { + StatusOr statusor = ShapeInference::InferGatherShape( + f32_5d_tensor_50_49_48_47_46_, s64_4d_tensor_10_9_5_7_6_, + HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/{4, 5, 6, 7, 8}, + /*elided_window_dims=*/{}, + /*gather_dims_to_operand_dims=*/{0, 1, 2, 3, 4}, + /*index_vector_dim=*/32), + /*window_bounds=*/{30, 29, 28, 27, 26}); + + ASSERT_FALSE(statusor.ok()); + EXPECT_THAT(statusor.status().error_message(), + HasSubstr("Gather index leaf dimension must be within [0, " + "rank(gather_indices) + 1)")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index c679d401c3691b14a43ce77cbe953cd4c64a9e92..6e9986165f7eaf71a964b42b734a5ae5db5e45d7 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -41,7 +41,32 @@ ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, on_device_shape_(on_device_shape), platform_(platform), device_ordinal_(device_ordinal), - buffers_(on_device_shape) {} + buffers_(&on_device_shape_) {} + +ShapedBuffer::ShapedBuffer(ShapedBuffer&& s) + : on_host_shape_(std::move(s.on_host_shape_)), + on_device_shape_(std::move(s.on_device_shape_)), + platform_(s.platform_), + device_ordinal_(s.device_ordinal_), + buffers_(std::move(s.buffers_)) { + // s.buffers_ has a pointer to s.on_device_shape_. When we move s.buffers_ + // into buffers_, we also need to update this pointer so that buffers_ doesn't + // point into s. + buffers_.replace_shape_ptr(&on_device_shape_); +} + +ShapedBuffer& ShapedBuffer::operator=(ShapedBuffer&& s) { + on_host_shape_ = std::move(s.on_host_shape_); + on_device_shape_ = std::move(s.on_device_shape_); + platform_ = s.platform_; + device_ordinal_ = s.device_ordinal_; + buffers_ = std::move(s.buffers_); + // buffers_ has a pointer to its on_device_shape_. When we move s.buffers_ + // into buffers_, we also need to update this pointer so that buffers_ doesn't + // point into s. + buffers_.replace_shape_ptr(&on_device_shape_); + return *this; +} void ShapedBuffer::clear() { for (auto& pair : buffers_) { @@ -99,6 +124,10 @@ ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape, device_ordinal), allocator_(allocator) {} +ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer, + DeviceMemoryAllocator* allocator) + : ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {} + ScopedShapedBuffer::~ScopedShapedBuffer() { // Deallocate all non-null buffers. A buffer may appear in more than one spot // in the shape (eg, a tuple with a repeated element) so keep track of what @@ -116,12 +145,8 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { } std::unique_ptr ScopedShapedBuffer::release() { - auto shaped_buffer = MakeUnique( - on_host_shape(), on_device_shape(), platform(), device_ordinal()); - - shaped_buffer->buffers() = buffers(); - clear(); - + auto shaped_buffer = MakeUnique(std::move(*this)); + buffers_ = ShapeTree(); return shaped_buffer; } diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index d397e47d2ca734458c7dc99baa5c81b16d0fd72b..b816df8385ef65b0b69ede1d6e65a1991b4bd7c6 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -87,18 +87,24 @@ class ShapedBuffer { string ToString() const; + ShapedBuffer(ShapedBuffer&& s); + ShapedBuffer& operator=(ShapedBuffer&&); + protected: + ShapedBuffer(const ShapedBuffer&) = delete; + ShapedBuffer& operator=(const ShapedBuffer&) = delete; + // The shape of the data when represented on the host. - const Shape on_host_shape_; + Shape on_host_shape_; // The shape of the data on the device. - const Shape on_device_shape_; + Shape on_device_shape_; // The platform the memory is allocated on. const perftools::gputools::Platform* platform_; // The device the memory is allocated on. - const int device_ordinal_; + int device_ordinal_; // The tree of device buffers. Its shape is on_device_shape(). ShapeTree buffers_; @@ -121,14 +127,20 @@ class ScopedShapedBuffer : public ShapedBuffer { ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, DeviceMemoryAllocator* allocator, int device_ordinal); + // Create a ScopedShapedBuffer by taking over the memory from the incoming + // ShapedBuffer. + ScopedShapedBuffer(ShapedBuffer shaped_buffer, + DeviceMemoryAllocator* allocator); + // Return the allocator used to allocate the device memory held in this // ScopedShapedBuffer. DeviceMemoryAllocator* memory_allocator() const { return allocator_; } - // Release all device memory owned by this ScopedShapedBuffer and return the - // device memory pointers in the form of a ShapedBuffer. Device memory - // pointers in this ScopedShapedBuffer object are set to null. This method is - // analogous to std::unique_ptr::release(). + // Release all device memory owned by this ScopedShapedBuffer and + // return the device memory pointers in the form of a + // ShapedBuffer. The returned ShapedBuffer takes over the memory + // from the ScopedShapedBuffer. The resulting ScopedShapedBuffer can + // only be destroyed. std::unique_ptr release(); // All buffers in the shape are deallocated on destruction. diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index fead9b92362bcd1974f2dff6e030bc47dfc5aa85..06735e9442942f3c69d1cd679857fe22f2fa6756 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -226,7 +226,8 @@ StatusOr UserComputation::AddParameterInstruction( return handle; } -Status UserComputation::AddSendInstruction(const SendRequest& send_request) { +StatusOr UserComputation::AddSendInstruction( + const SendRequest& send_request) { tensorflow::mutex_lock lock(mutex_); // Check if the operand of the instruction is valid. @@ -244,7 +245,7 @@ Status UserComputation::AddSendInstruction(const SendRequest& send_request) { VLOG(1) << "AddSendInstruction (" << GetVersionedHandleInternal() << "), data handle " << handle.handle() << ": " << send_request.ShortDebugString(); - return Status::OK(); + return handle; } StatusOr UserComputation::AddRecvInstruction( @@ -315,6 +316,36 @@ StatusOr UserComputation::AddConstantInstruction( return handle; } +StatusOr UserComputation::AddGatherInstruction( + const GatherRequest& gather_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* input_request, + LookUpRequest(gather_request.input())); + TF_ASSIGN_OR_RETURN(const OperationRequest* gather_indices_request, + LookUpRequest(gather_request.gather_indices())); + + TF_ASSIGN_OR_RETURN( + Shape shape, + ShapeInference::InferGatherShape( + input_request->output_shape(), gather_indices_request->output_shape(), + gather_request.dimension_numbers(), + AsInt64Slice(gather_request.window_bounds()))); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_gather_request() = gather_request; + + VLOG(1) << "AddGatherInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << gather_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddGetTupleElementInstruction( const GetTupleElementRequest& get_tuple_element_request) { tensorflow::mutex_lock lock(mutex_); @@ -1276,6 +1307,28 @@ StatusOr UserComputation::AddCustomCallInstruction( return handle; } +StatusOr UserComputation::AddHostComputeInstruction( + const HostComputeRequest& host_compute_request) { + tensorflow::mutex_lock lock(mutex_); + + for (const ComputationDataHandle& handle : host_compute_request.operands()) { + TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); + } + + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = host_compute_request.shape(); + *request.mutable_request()->mutable_host_compute_request() = + host_compute_request; + + VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << host_compute_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddDotInstruction( const DotRequest& dot_request) { tensorflow::mutex_lock lock(mutex_); @@ -1713,6 +1766,11 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kHostComputeRequest: { + *is_functional = false; + break; + } + case OpRequest::kCallRequest: { const CallRequest& call_request = request.request().call_request(); for (const ComputationDataHandle& handle : call_request.operands()) { @@ -1991,6 +2049,16 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kGatherRequest: { + PureFunctionalVisitor(session_computation, + request.request().gather_request().input(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + request.request().gather_request().gather_indices(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; @@ -2643,6 +2711,15 @@ static void ForEachOperand( break; } + case OpRequest::kHostComputeRequest: { + const HostComputeRequest& hc_request = + request.request().host_compute_request(); + for (const ComputationDataHandle& operand : hc_request.operands()) { + apply(operand); + } + break; + } + case OpRequest::kDotRequest: { const DotRequest& dot_request = request.request().dot_request(); apply(dot_request.rhs()); @@ -2684,6 +2761,13 @@ static void ForEachOperand( break; } + case OpRequest::kGatherRequest: { + const GatherRequest& gather_request = request.request().gather_request(); + apply(gather_request.input()); + apply(gather_request.gather_indices()); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; @@ -3299,6 +3383,22 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kHostComputeRequest: { + const HostComputeRequest& host_compute_request = + request.request().host_compute_request(); + std::vector operands; + for (const ComputationDataHandle& operand : + host_compute_request.operands()) { + operands.push_back(lookup_instruction(operand)); + } + auto output_shape = host_compute_request.shape(); + auto channel_name = host_compute_request.channel_name(); + auto cost_estimate_ns = host_compute_request.cost_estimate_ns(); + hlo_instruction = add_instruction(HloInstruction::CreateHostCompute( + output_shape, operands, channel_name, cost_estimate_ns)); + break; + } + case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); @@ -3401,6 +3501,20 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kGatherRequest: { + const GatherRequest& gather_request = request.request().gather_request(); + HloInstruction* input_operand = + lookup_instruction(gather_request.input()); + HloInstruction* gather_indices_operand = + lookup_instruction(gather_request.gather_indices()); + std::vector window_bounds; + c_copy(gather_request.window_bounds(), std::back_inserter(window_bounds)); + hlo_instruction = add_instruction(HloInstruction::CreateGather( + request.output_shape(), input_operand, gather_indices_operand, + gather_request.dimension_numbers(), window_bounds)); + break; + } + case OpRequest::OP_NOT_SET: LOG(FATAL) << "OperationRequest doesn't contain a request"; diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 54bb24d6d7fe7aa8cc7c684795e40464e4eb6614..5544c868fe905c1ca7e6cab32738440add2e3b4f 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -149,6 +149,10 @@ class UserComputation { StatusOr AddOutfeedInstruction( const OutfeedRequest& outfeed_request); + // Enqueues a host compute instruction onto this user computation. + StatusOr AddHostComputeInstruction( + const HostComputeRequest& host_compute_request); + // Enqueues a call instruction onto this user computation. StatusOr AddCallInstruction( const CallRequest& call_request, @@ -232,12 +236,17 @@ class UserComputation { const UserComputation& false_computation); // Enqueues a Send instruction onto this user computation. - Status AddSendInstruction(const SendRequest& send_request); + StatusOr AddSendInstruction( + const SendRequest& send_request); // Enqueues a Recv instruction onto this user computation. StatusOr AddRecvInstruction( const RecvRequest& recv_request); + // Enqueues a Gather instruction onto this user computation. + StatusOr AddGatherInstruction( + const GatherRequest& gather_request); + // Returns the user-provided name of this user computation, which is provided // via the XLA computation-building API. const string& name() const { return name_; } diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc index a5f9b01f011ce04f1114c74391a967c62f015221..3ef0cdff6751258e4489ce350deb0931fdf69ef9 100644 --- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -106,20 +106,12 @@ static bool NotWorthHoistingIndividually(const HloInstruction& instruction) { case HloOpcode::kBitcast: case HloOpcode::kBroadcast: case HloOpcode::kConstant: + case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kSlice: + case HloOpcode::kTranspose: case HloOpcode::kTuple: return true; - - case HloOpcode::kTranspose: - return ShapeUtil::TransposeIsBitcast( - /*input_shape=*/instruction.operand(0)->shape(), - /*output_shape=*/instruction.shape(), instruction.dimensions()); - - case HloOpcode::kReshape: - return ShapeUtil::ReshapeIsBitcast( - /*input_shape=*/instruction.operand(0)->shape(), - /*output_shape=*/instruction.shape()); } } diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index d752619bd65751779c24f061e44e206d66b01465..280f02e88675381bd75108bfae0dd22c462ba718 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -143,6 +143,18 @@ class ShapeTree { // Return the shape represented with this ShapeTree. const Shape& shape() const { return *shape_; } + // Replaces *only* the underlying shape of this ShapeTree. The caller must own + // the Shape object and hence shape_storage_ is not updated. + // + // Only safe to use this if the ShapeTree was constructed with 'explicit + // ShapeTree(const Shape* shape)' or is moved from one such ShapeTree. The + // caller must ensure that the input shape is consistent with the underlying + // tree. + void replace_shape_ptr(const Shape* shape) { + CHECK(shape_storage_.get() == nullptr); + shape_ = shape; + } + // Returns true if the node at the given index is a leaf node (an array // shape). bool IsLeaf(const ShapeIndex& index) const { diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index d63e16ce2bf51cff0d113640d31ec6e70bfaf421..315278901638ac2efa991fdeb7ca76f369321288 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -630,6 +630,19 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return SameDimensions(lhs, rhs); } +/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, + const Shape& rhs) { + if (lhs.element_type() == TUPLE) { + return rhs.element_type() == TUPLE && + ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringFpPrecision); + } + if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { + return CompatibleIgnoringElementType(lhs, rhs); + } + return false; +} + /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, int64 dimension_number) { return shape.dimensions(GetDimensionNumber(shape, dimension_number)); @@ -1060,11 +1073,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, tensorflow::gtl::ArraySlice dimension_mapping) { - // Can't insert bitcasts without layout information. - if (!LayoutUtil::HasLayout(input_shape) && - !LayoutUtil::HasLayout(output_shape)) { - return false; - } + CHECK(LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape)); // Padding is not handled. if (LayoutUtil::IsPadded(input_shape) && LayoutUtil::IsPadded(output_shape)) { @@ -1093,11 +1103,8 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, /* static */ bool ShapeUtil::ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape) { - // Can't convert reshapes into bitcasts without layout information. - if (!LayoutUtil::HasLayout(input_shape) || - !LayoutUtil::HasLayout(output_shape)) { - return false; - } + CHECK(LayoutUtil::HasLayout(input_shape) && + LayoutUtil::HasLayout(output_shape)); // Padding is not handled. if (LayoutUtil::IsPadded(input_shape) || LayoutUtil::IsPadded(output_shape)) { diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 453d4ec04726a4dd3851b8becb439bb7506e4ca9..8ee263fe5e5fc20edf6d8ce1f56fe72b27b645d0 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -23,6 +23,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -62,6 +63,9 @@ class ShapeIndex { void push_back(int64 value) { indices_.push_back(value); } void pop_back() { indices_.pop_back(); } + // push_front is O(n^2), but shapes don't usually have a ton of dimensions. + void push_front(int64 value) { indices_.insert(indices_.begin(), value); } + std::vector::const_iterator begin() const { return indices_.begin(); } std::vector::const_iterator end() const { return indices_.end(); } std::vector::iterator begin() { return indices_.begin(); } @@ -211,6 +215,31 @@ class ShapeUtil { return lhs.element_type() == rhs.element_type(); } + // As SameElementType, but allows floating point types to have different + // precisions. + static bool SameElementTypeIgnoringFpPrecision(const Shape& a, + const Shape& b) { + if (ElementIsFloating(a) && ElementIsFloating(b)) { + return true; + } + return ShapeUtil::SameElementType(a, b); + } + + // Returns the higher-precision element type if a and b are both floating + // point types; otherwise, checks that that they have the same element type + // and returns it. + static PrimitiveType HigherPrecisionElementType(const Shape& a, + const Shape& b) { + if (SameElementType(a, b)) { + return a.element_type(); + } + CHECK(SameElementTypeIgnoringFpPrecision(a, b)); + return primitive_util::BitWidth(a.element_type()) < + primitive_util::BitWidth(b.element_type()) + ? b.element_type() + : a.element_type(); + } + // Returns true if the rank, dimension sizes, and element type are // identical. Layout is ignored. Tuple elements are compared recursively for // compatibility. @@ -221,6 +250,10 @@ class ShapeUtil { // compatibility. static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs); + // As Compatible, but allow one of lhs and rhs to be BF16 while the other + // being F32. Tuple elements are compared recursively for compatibility. + static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); @@ -489,12 +522,16 @@ class ShapeUtil { // Returns whether a transpose from input_shape to output_shape with dimension // mapping "dimension_mapping" produces a result which is bit-wise identical // to its input and thus may be replaced with a bitcast. + // + // Precondition: Both input_shape and output_shape have explicit layouts. static bool TransposeIsBitcast( const Shape& input_shape, const Shape& output_shape, tensorflow::gtl::ArraySlice dimension_mapping); // Returns whether a reshape from "input_shape" to "output_shape" is a // bitcast. + // + // Precondition: Both input_shape and output_shape have explicit layouts. static bool ReshapeIsBitcast(const Shape& input_shape, const Shape& output_shape); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 81ba7afb95265398e830e26122cd0056a32daee3..4db97d45b20b86dc60531845c6e28a223203ff7f 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -170,6 +170,18 @@ TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) { EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2)); } +TEST(ShapeUtilTest, CompatibleIgnoringFpPrecision) { + Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2}); + Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); + ASSERT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); +} + +TEST(ShapeUtilTest, IncompatibleIgnoringFpPrecision) { + Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2}); + Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2}); + ASSERT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); +} + TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2}); @@ -184,6 +196,14 @@ TEST(ShapeUtilTest, CompatibleTuples) { EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2)); } +TEST(ShapeUtilTest, CompatibleTuplesIgnoringFpPrecision) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(BF16, {3, 2}), ShapeUtil::MakeShape(F32, {4, 5})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})}); + EXPECT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2)); +} + TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); @@ -193,6 +213,14 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); } +TEST(ShapeUtilTest, IncompatibleTuplesIgnoringFpPrecision) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(BF16, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})}); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2)); +} + TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 60f3e6180746f2761f093a4b54b136ac4a841031..19b3dfae4ee8cfe2be8ea6be82d6f2fc25e67274 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -271,6 +271,9 @@ cc_library( xla_test( name = "bad_rng_shape_validation_test", srcs = ["bad_rng_shape_validation_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -290,6 +293,9 @@ xla_test( xla_test( name = "check_execution_arity_test", srcs = ["check_execution_arity_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -309,6 +315,9 @@ xla_test( xla_test( name = "query_inferred_shape_test", srcs = ["query_inferred_shape_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -366,6 +375,9 @@ xla_test( xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", @@ -430,6 +442,9 @@ xla_test( xla_test( name = "pred_test", srcs = ["pred_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla/client:computation_builder", @@ -444,6 +459,9 @@ xla_test( xla_test( name = "select_test", srcs = ["select_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", @@ -460,6 +478,7 @@ xla_test( xla_test( name = "conditional_test", srcs = ["conditional_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", @@ -476,6 +495,7 @@ xla_test( xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", @@ -586,6 +606,7 @@ xla_test( tags = [ "enormous", "manual", + "notap", ], deps = [ ":client_library_test_base", @@ -621,8 +642,10 @@ xla_test( xla_test( name = "dot_operation_test", srcs = ["dot_operation_test.cc"], + shard_count = 20, tags = [ "enable_for_xla_interpreter", + "optonly", ], deps = [ "//tensorflow/compiler/xla:array2d", @@ -641,32 +664,7 @@ xla_test( ], ) -# Tests the dot operation in some cases that can be performed via a -# runtime call on some backends - e.g. a runtime call to Eigen. -xla_test( - name = "dot_operation_runtime_test", - srcs = ["dot_operation_test.cc"], - tags = [ - "enable_for_xla_interpreter", - ], - deps = [ - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:array3d", - "//tensorflow/compiler/xla:reference_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/client:local_client", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:test_utils", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:framework_internal", - "//tensorflow/core:lib", - "//tensorflow/core:test", - ], -) - -# Repeat dot_operation_runtime_test with single-threded eigen. +# Repeat dot_operation_runtime_test with single-threaded eigen. xla_test( name = "dot_operation_single_threaded_runtime_test", srcs = ["dot_operation_test.cc"], @@ -678,6 +676,8 @@ xla_test( "--xla_cpu_multi_thread_eigen=false", ], }, + shard_count = 20, + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -698,6 +698,9 @@ xla_test( xla_test( name = "transpose_test", srcs = ["transpose_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -716,6 +719,9 @@ xla_test( xla_test( name = "constants_test", srcs = ["constants_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -876,8 +882,7 @@ xla_test( name = "half_test", srcs = ["half_test.cc"], backends = [ - # TODO(b/72509305): Flaky (fails with SEGV) as of 2018-01-25 - # "cpu", + "cpu", "gpu", ], deps = [ @@ -901,6 +906,9 @@ xla_test( name = "slice_test", srcs = ["slice_test.cc"], shard_count = 40, + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:reference_util", @@ -917,6 +925,9 @@ xla_test( xla_test( name = "multidimensional_slice_test", srcs = ["multidimensional_slice_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -979,6 +990,9 @@ xla_test( xla_test( name = "vector_ops_reduce_test", srcs = ["vector_ops_reduce_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -997,6 +1011,10 @@ xla_test( name = "reduce_test", srcs = ["reduce_test.cc"], shard_count = 40, + tags = [ + "enable_for_xla_interpreter", + "optonly", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1090,6 +1108,9 @@ xla_test( xla_test( name = "copy_test", srcs = ["copy_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla:array2d", @@ -1108,6 +1129,9 @@ xla_test( xla_test( name = "reduce_hlo_test", srcs = ["reduce_hlo_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ ":client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1121,6 +1145,9 @@ xla_test( xla_test( name = "call_test", srcs = ["call_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1158,6 +1185,9 @@ xla_test( xla_test( name = "binop_scaling_test", srcs = ["binop_scaling_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1174,6 +1204,9 @@ xla_test( xla_test( name = "broadcast_simple_test", srcs = ["broadcast_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1191,6 +1224,9 @@ xla_test( xla_test( name = "pad_test", srcs = ["pad_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1211,6 +1247,9 @@ xla_test( xla_test( name = "fmax_test", srcs = ["fmax_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", @@ -1224,6 +1263,9 @@ xla_test( xla_test( name = "log_test", srcs = ["log_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", @@ -1237,6 +1279,9 @@ xla_test( xla_test( name = "matrix_ops_simple_test", srcs = ["matrix_ops_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -1279,6 +1324,9 @@ xla_test( name = "reshape_test", srcs = ["reshape_test.cc"], shard_count = 30, + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1305,6 +1353,9 @@ xla_test( xla_test( name = "reverse_test", srcs = ["reverse_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array4d", @@ -1321,6 +1372,9 @@ xla_test( xla_test( name = "vector_ops_simple_test", srcs = ["vector_ops_simple_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:shape_util", @@ -1344,6 +1398,9 @@ xla_test( xla_test( name = "concat_test", srcs = ["concat_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -1364,8 +1421,12 @@ xla_test( xla_test( name = "convert_test", srcs = ["convert_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", @@ -1420,6 +1481,9 @@ xla_test( xla_test( name = "floor_ceil_test", srcs = ["floor_ceil_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", @@ -1503,6 +1567,9 @@ xla_test( xla_test( name = "replay_test", srcs = ["replay_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -1525,6 +1592,9 @@ xla_test( xla_test( name = "broadcast_test", srcs = ["broadcast_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", @@ -1592,6 +1662,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1618,6 +1689,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 87ac7731ba69b79a73b5f2b2f360c7fc6ae1198f..8b35259013200e96807446803c696451a8db80a9 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -101,6 +101,33 @@ XLA_TEST_F(ArrayElementwiseOpTest, NegConstantC64) { {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, NegConstantS64) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({ + -1, + 1, + 0, + 0x12345678, + static_cast(0xffffffff12345678l), + static_cast(0x8000000000000000LL), + static_cast(0x8000000000000001LL), + }); + auto result = builder.Neg(a); + LOG(INFO) << -static_cast(0x7FFFFFFFFFFFFFFFLL); + + ComputeAndCompareR1(&builder, + { + 1, + -1, + 0, + -0x12345678, + 0xedcba988, + static_cast(0x8000000000000000LL), + -static_cast(0x8000000000000001LL), + }, + {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, IsFiniteZeroElementF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); @@ -186,6 +213,86 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) { ComputeAndCompareR1(&builder, {}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) { + ComputationBuilder b(client_, TestName()); + + std::vector lhs{0xFFFFFFFF, + static_cast(-1), + 0, + 0, + 0x7FFFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFLL, + 0x8000000000000000LL, + 0x8000000000000000LL, + 1}; + std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); + auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + std::unique_ptr lhs_data = + client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + + std::vector rhs{1, + 0x7FFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL, + 0x8000000000000000LL, + 0, + static_cast(-1), + 0, + 1, + 0x8000000000000000LL}; + std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); + auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + std::unique_ptr rhs_data = + client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + + auto add = b.Add(lhs_param, rhs_param); + + std::vector expected(lhs.size()); + for (int64 i = 0; i < lhs.size(); ++i) { + expected[i] = lhs[i] + rhs[i]; + } + + ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) { + ComputationBuilder b(client_, TestName()); + + std::vector lhs{static_cast(0x8000000000000000LL), + static_cast(0x8000000000000000LL), + -1, + 0x7FFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL, + 1, + 0, + -1}; + std::unique_ptr lhs_literal = Literal::CreateR1({lhs}); + auto lhs_param = b.Parameter(0, lhs_literal->shape(), "lhs_param"); + std::unique_ptr lhs_data = + client_->TransferToServer(*lhs_literal).ConsumeValueOrDie(); + + std::vector rhs{-1, + 0, + static_cast(0x8000000000000000LL), + 1, + 0, + 0x7FFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL, + 0x7FFFFFFFFFFFFFFFLL}; + std::unique_ptr rhs_literal = Literal::CreateR1({rhs}); + auto rhs_param = b.Parameter(1, rhs_literal->shape(), "rhs_param"); + std::unique_ptr rhs_data = + client_->TransferToServer(*rhs_literal).ConsumeValueOrDie(); + + auto sub = b.Sub(lhs_param, rhs_param); + + std::vector expected(lhs.size()); + for (int64 i = 0; i < lhs.size(); ++i) { + expected[i] = lhs[i] - rhs[i]; + } + + ComputeAndCompareR1(&b, expected, {lhs_data.get(), rhs_data.get()}); +} + TEST_P(ArrayElementwiseOpTestParamCount, AddManyValues) { const int count = GetParam(); ComputationBuilder builder(client_, TestName()); @@ -847,68 +954,76 @@ XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { ComputationBuilder builder(client_, TestName()); - auto a = - builder.ConstantR1({static_cast(0x12345678), - static_cast(0xF0001000), 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15}); + auto a = builder.ConstantR1({static_cast(0x12345678), + static_cast(0xF0001000), 1, 3, 77, + 1, -3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, -1}); auto out = builder.ShiftLeft(a, b); - ComputeAndCompareR1( - &builder, - {static_cast(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {}); + ComputeAndCompareR1(&builder, + {static_cast(0x23456780), 0x00100000, 0x4, + 0x180, 2523136, 0, 0, 0}, + {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { ComputationBuilder builder(client_, TestName()); - auto a = - builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2}); + auto a = builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77, + 1, -3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, -1}); auto out = builder.ShiftRightArithmetic(a, b); - ComputeAndCompareR1(&builder, - {static_cast(0xF9234567), - static_cast(0x00100010), 0, 0, 19}, - {}); + ComputeAndCompareR1( + &builder, + {static_cast(0xF9234567), static_cast(0x00100010), 0, 0, 19, + 0, -1, 0}, + {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { ComputationBuilder builder(client_, TestName()); - auto a = - builder.ConstantR1({static_cast(0x92345678), - static_cast(0x10001000), 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5}); + auto a = builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77, + 1, -3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, -1}); auto out = builder.ShiftRightLogical(a, b); - ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); + ComputeAndCompareR1(&builder, + {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({0x12345678, 0xF0001000, 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 15}); + auto a = builder.ConstantR1( + {0x12345678, 0xF0001000, 1, 3, 77, 1, ~3u, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15, 32, 100, ~0u}); auto out = builder.ShiftLeft(a, b); ComputeAndCompareR1( - &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {}); + &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 2}); + auto a = builder.ConstantR1( + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2, 32, 100, ~0u}); auto out = builder.ShiftRightArithmetic(a, b); - ComputeAndCompareR1(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {}); + ComputeAndCompareR1( + &builder, {0xF9234567, 0x00100010, 0, 0, 19, 0, ~0u, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); - auto b = builder.ConstantR1({4, 8, 2, 7, 5}); + auto a = builder.ConstantR1( + {0x92345678, 0x10001000, 1, 3, 77, 1, ~3u, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5, 32, 100, ~0u}); auto out = builder.ShiftRightLogical(a, b); - ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); + ComputeAndCompareR1(&builder, + {0x09234567, 0x00100010, 0, 0, 2, 0, 0, 0}, {}); } XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { @@ -2121,6 +2236,44 @@ XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { + // The input tensor is large enough to exercise the vectorized exp + // implementation on XLA CPU. + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr input_literal = Literal::CreateR1( + {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, + -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, + 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, + 1.74e+04, 1.89e+05, 1.9e+05, 1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07, + 1.66e+07, 1e+07, 1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09, + 1.44e+10, 1.5e+10, 1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12, + 1.4e+12, 1.03e+13, 1.6e+13, 1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15, + 1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17, + 2e+18, 1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20, + 1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21, 1.35e+22, 1.84e+22, 1.02e+22, + 1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25, + 1.62e+25, 1.2e+26, 1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28, + 1.5e+28, 1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30, 1.81e+30, 1.34e+30, + 1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33, + 1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + + auto input = builder.Parameter(0, input_literal->shape(), "input"); + builder.Log(input); + + std::vector expected_result; + int64 input_size = input_literal->shape().dimensions(0); + expected_result.reserve(input_size); + for (int64 i = 0; i < input_size; i++) { + expected_result.push_back(std::log(input_literal->Get({i}))); + } + + ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, + error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // a ------ (add) --------- (add) // / / diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 627a9c3e7d9f6eb8d360228362ea5adf12c6c798..3f6fd7c65d3360a622dbf754833009fb20410535 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -62,6 +62,10 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { auto ax = builder.Mul(alpha, x); auto axpy = builder.Add(ax, y); + TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); + + EXPECT_EQ("() -> f32[10]", ShapeUtil::HumanString(shape)); + std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index f66e3b57bf45fbc9f8ea786146d6fffe5d55a262..59d6d7a4153be1b76ed8195a12a90cb103baa422 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -106,11 +107,108 @@ TEST_F(ConvertTest, ConvertR1F32ToR1S32) { XLA_TEST_F(ConvertTest, ConvertR1S64ToR1F32) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({32, 64}); - builder.ConvertElementType(a, F32); + std::vector arg{ + -9223371216516022272, + -2, + -1, + -0x7FFFFFFF, + -0x80000000, + 0, + 1, + 2, + 1073742145, + 1073742656, + 0x7FFFFFFF, + 0x80000000, + 826720496944058148, + 4296062029846194332, + 0x0007FB72E4000000LL, + 0x0007FB72E4000001LL, + 0x0007FB72E6000000LL, + 0x0007FB72E7000000LL, + 0x0007FB72E7FFFFFFLL, + 0x0007FB72E8000000LL, + 0x0007FB72E8000001LL, + 0x0007FB72EA000000LL, + 0x0007FB72EB000000LL, + 0x0007FB72EBFFFFFFLL, + 0x0007FB72EC000000LL, + 0x7FFFFF0000000000LL, + 0x7FFFFF8000000000LL, + 0x7FFFFFFFFFFFFF00, + static_cast(0xFFFFFFFFFFFFFFFF), + static_cast(0x0000f234e67e0001LL), + static_cast(0x8000000000000000), + static_cast(0x8000000000000000LL), + static_cast(0x8000000000000001LL), + static_cast(0x8000008000000000LL), + static_cast(0x8000010000000000LL), + }; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, F32); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} - std::vector expected = {32.0, 64.0}; - ComputeAndCompareR1(&builder, expected, {}); +XLA_TEST_F(ConvertTest, ConvertR1U32ToR1F32) { + ComputationBuilder builder(client_, TestName()); + std::vector arg{0, 1, 0x1000, 0x7fffffff, + 0x80000000, 0x80000001, 0x80000002, 0x80000003, + 0x80000080, 0x80000081, 0x80000082, 0xFFFFFFFF}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, F32); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1U32ToR1S64) { + ComputationBuilder builder(client_, TestName()); + std::vector arg{0, 1, 0x1000, 0x7fffffff, 0x80000082, 0xFFFFFFFF}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, S64); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1S32ToR1S64) { + ComputationBuilder builder(client_, TestName()); + std::vector arg{0, 1, 0x1000, -1, -0x1000}; + std::unique_ptr arg_literal = Literal::CreateR1({arg}); + auto arg_param = builder.Parameter(0, arg_literal->shape(), "arg_param"); + std::unique_ptr arg_data = + client_->TransferToServer(*arg_literal).ConsumeValueOrDie(); + + builder.ConvertElementType(arg_param, S64); + + std::vector expected(arg.size()); + for (int64 i = 0; i < arg.size(); ++i) { + expected[i] = static_cast(arg[i]); + } + ComputeAndCompareR1(&builder, expected, {arg_data.get()}); } XLA_TEST_F(ConvertTest, ConvertR1U8ToR1F32) { @@ -208,5 +306,65 @@ TEST_F(ConvertTest, ConvertReshape) { ComputeAndCompareR0(&builder, 42.0f, {}, ErrorSpec(0.0001)); } +std::vector GetInterestingF16ConversionTestCases() { + float infinity = std::numeric_limits::infinity(); + float half_min_positive_normal = + tensorflow::bit_cast(0x38800000); + float half_max_subnormal = tensorflow::bit_cast(0x387fc000); + float half_min_positive_subnormal = + tensorflow::bit_cast(0x33800000); + float half_max = 65504.0f; + + std::vector test_cases( + {-infinity, -(half_max * 2 + 1), -half_max, -42.0f, -1.0f, + -half_min_positive_subnormal, -half_max_subnormal, + -half_min_positive_normal, -0.0f, 0.0f, half_min_positive_subnormal, + half_max_subnormal, half_min_positive_normal, 1.0f, 42.0f, half_max, + (half_max * 2 + 1), infinity}); + return test_cases; +} + +XLA_TEST_F(ConvertTest, ConvertR1F16ToR1F32) { + std::vector test_cases = GetInterestingF16ConversionTestCases(); + std::vector input; + c_transform(test_cases, std::back_inserter(input), + [](float f) { return Eigen::half(f); }); + std::vector expected_output; + c_transform(input, std::back_inserter(expected_output), + [](Eigen::half h) { return static_cast(h); }); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dot_lhs_handle, + client_->TransferToServer(*Literal::CreateR1(input))); + + ComputationBuilder builder(client_, TestName()); + builder.ConvertElementType( + builder.Parameter( + 0, ShapeUtil::MakeShape(F16, {static_cast(input.size())}), + "param"), + F32); + + ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); +} + +XLA_TEST_F(ConvertTest, ConvertR1F32ToR1F16) { + std::vector input = GetInterestingF16ConversionTestCases(); + std::vector expected_output; + c_transform(input, std::back_inserter(expected_output), + [](float f) { return Eigen::half(f); }); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr dot_lhs_handle, + client_->TransferToServer(*Literal::CreateR1(input))); + + ComputationBuilder builder(client_, TestName()); + builder.ConvertElementType( + builder.Parameter( + 0, ShapeUtil::MakeShape(F32, {static_cast(input.size())}), + "param"), + F16); + + ComputeAndCompareR1(&builder, expected_output, {dot_lhs_handle.get()}); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 0ceb9aff378ae8aa8098be9360310b1d78d31ab2..e2b5c91653fa6db5df86404c6c5f9158b0d484e1 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -53,157 +53,199 @@ class ConvolutionTest : public ClientLibraryTestBase { #endif }; -XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) { - const int kInputActivationSizeY = 3; - const int kInputActivationSizeX = 3; - const int kInputActivationSizeZ = 256; - const int kKernelSizeX = 2; - const int kKernelSizeY = 2; - const int kOutputActivationSizeZ = 256; - const int kMiniBatchSize = 4; - auto alhs = - MakeUnique>(kMiniBatchSize, kInputActivationSizeZ, - kInputActivationSizeY, kInputActivationSizeX); - alhs->FillWithMultiples(1.0f); - ASSERT_EQ(3, alhs->width()); - ASSERT_EQ(3, alhs->height()); - - auto arhs = - MakeUnique>(kOutputActivationSizeZ, kInputActivationSizeZ, - kKernelSizeY, kKernelSizeX); - Array2D rhs_raster({ - {1.0f, 0.0f}, // row 0 - {0.0f, 0.0f}, // row 1 - }); - arhs->FillWithYX(rhs_raster); - ASSERT_EQ(2, arhs->width()); - ASSERT_EQ(2, arhs->height()); +#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU) +using TestTypes = ::testing::Types; +#else +using TestTypes = ::testing::Types; +#endif - ComputationBuilder builder(client_, TestName()); - auto lhs = builder.ConstantR4FromArray4D(*alhs); - auto rhs = builder.ConstantR4FromArray4D(*arhs); - auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); +template +Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice dimensions); - ComputeAndCompare(&builder, conv, {}, error_spec_); +template <> +Shape MakeShapeWrapper(tensorflow::gtl::ArraySlice dimensions) { + return ShapeUtil::MakeShape(F32, dimensions); } -TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); +template <> +Shape MakeShapeWrapper( + tensorflow::gtl::ArraySlice dimensions) { + return ShapeUtil::MakeShape(F16, dimensions); +} - Array4D input_data(1, 1, 1, 2); - input_data.FillWithYX(Array2D({ - {1, 2}, - })); - Array4D filter_data(1, 1, 1, 2); - filter_data.FillWithYX(Array2D({ - {5, 6}, - })); +template +class ForwardPassConvolution_3x3x256_256_OutputZ_Iota : public ConvolutionTest { + public: + void RunTest() { + const int kInputActivationSizeY = 3; + const int kInputActivationSizeX = 3; + const int kInputActivationSizeZ = 256; + const int kKernelSizeX = 2; + const int kKernelSizeY = 2; + const int kOutputActivationSizeZ = 256; + const int kMiniBatchSize = 4; + auto alhs = + MakeUnique>(kMiniBatchSize, kInputActivationSizeZ, + kInputActivationSizeY, kInputActivationSizeX); + alhs->FillWithMultiples(static_cast(1.0f)); + ASSERT_EQ(3, alhs->width()); + ASSERT_EQ(3, alhs->height()); + + auto arhs = + MakeUnique>(kOutputActivationSizeZ, kInputActivationSizeZ, + kKernelSizeY, kKernelSizeX); + Array2D rhs_raster({ + {1.0f, 0.0f}, // row 0 + {0.0f, 0.0f}, // row 1 + }); + arhs->FillWithYX(rhs_raster); + ASSERT_EQ(2, arhs->width()); + ASSERT_EQ(2, arhs->height()); + + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR4FromArray4D(*alhs); + auto rhs = builder.ConstantR4FromArray4D(*arhs); + auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + + ComputeAndCompare(&builder, conv, {}, error_spec_); + } +}; - ComputeAndCompare(&builder, conv, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, - error_spec_); +TYPED_TEST_CASE(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, TestTypes); +XLA_TYPED_TEST(ForwardPassConvolution_3x3x256_256_OutputZ_Iota, Types) { + this->RunTest(); } +template +class Convolve_1x1x1x2_1x1x1x2_Valid : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + Shape input_shape = MakeShapeWrapper({1, 1, 1, 2}); + Shape filter_shape = MakeShapeWrapper({1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 1, 2); + input_data.FillWithYX(Array2D({ + {1.0f, 2.0f}, + })); + Array4D filter_data(1, 1, 1, 2); + filter_data.FillWithYX(Array2D({ + {5.0f, 6.0f}, + })); + + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve_1x1x1x2_1x1x1x2_Valid, TestTypes); +TYPED_TEST(Convolve_1x1x1x2_1x1x1x2_Valid, Types) { this->RunTest(); } + // Tests valid padding for 2D convolution in raster space. -TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); +template +class Convolve_1x1x4x4_1x1x2x2_Valid : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); + Shape filter_shape = MakeShapeWrapper({1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 4, 4); + input_data.FillWithYX(Array2D({ + {1.0f, 2.0f, 3.0f, 4.0f}, + {5.0f, 6.0f, 7.0f, 8.0f}, + {9.0f, 10.0f, 11.0f, 12.0f}, + {13.0f, 14.0f, 15.0f, 16.0f}, + })); + Array4D filter_data(1, 1, 2, 2); + filter_data.FillWithYX(Array2D({ + {5.0f, 6.0f}, + {7.0f, 8.0f}, + })); + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, + error_spec_); + } +}; - Array4D input_data(1, 1, 4, 4); - // clang-format off - input_data.FillWithYX(Array2D({ - {1, 2, 3, 4 }, - {5, 6, 7, 8 }, - {9, 10, 11, 12}, - {13, 14, 15, 16}, - })); - // clang-format on - Array4D filter_data(1, 1, 2, 2); - // clang-format off - filter_data.FillWithYX(Array2D({ - {5, 6}, - {7, 8}, - })); - // clang-format on - ComputeAndCompare(&builder, conv, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, - error_spec_); -} +TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Valid, TestTypes); +TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Valid, Types) { this->RunTest(); } // Tests same padding for 2D convolution in raster space. -TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); - - Array4D input_data(1, 1, 4, 4); - // clang-format off - input_data.FillWithYX(Array2D({ - {1, 2, 3, 4 }, - {5, 6, 7, 8 }, - {9, 10, 11, 12}, - {13, 14, 15, 16}, - })); - // clang-format on - Array4D filter_data(1, 1, 2, 2); - // clang-format off - filter_data.FillWithYX(Array2D({ - {5, 6}, - {7, 8}, - })); - // clang-format on - ComputeAndCompare(&builder, conv, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, - error_spec_); -} +template +class Convolve_1x1x4x4_1x1x2x2_Same : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); + Shape filter_shape = MakeShapeWrapper({1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + + Array4D input_data(1, 1, 4, 4); + input_data.FillWithYX(Array2D({ + {1.0f, 2.0f, 3.0f, 4.0f}, + {5.0f, 6.0f, 7.0f, 8.0f}, + {9.0f, 10.0f, 11.0f, 12.0f}, + {13.0f, 14.0f, 15.0f, 16.0f}, + })); + Array4D filter_data(1, 1, 2, 2); + filter_data.FillWithYX(Array2D({ + {5.0f, 6.0f}, + {7.0f, 8.0f}, + })); + + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x2x2_Same, TestTypes); +TYPED_TEST(Convolve_1x1x4x4_1x1x2x2_Same, Types) { this->RunTest(); } // Tests same padding for 2D convolution in raster space with an odd sized // kernel. -TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { - ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); - - Array4D input_data(1, 1, 4, 4); - // clang-format off - input_data.FillWithYX(Array2D({ - {1, 2, 3, 4 }, - {5, 6, 7, 8 }, - {9, 10, 11, 12}, - {13, 14, 15, 16}, - })); - // clang-format on - Array4D filter_data(1, 1, 3, 3); - // clang-format off - filter_data.FillWithYX(Array2D({ - { 5, 6, 7}, - { 8, 9, 10}, - {11, 12, 13}, - })); - // clang-format on - ComputeAndCompare(&builder, conv, - {std::move(*Literal::CreateFromArray(input_data)), - std::move(*Literal::CreateFromArray(filter_data))}, - error_spec_); -} +template +class Convolve_1x1x4x4_1x1x3x3_Same : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + Shape input_shape = MakeShapeWrapper({1, 1, 4, 4}); + Shape filter_shape = MakeShapeWrapper({1, 1, 3, 3}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + + Array4D input_data(1, 1, 4, 4); + input_data.FillWithYX(Array2D({{1.0f, 2.0f, 3.0f, 4.0f}, + {5.0f, 6.0f, 7.0f, 8.0f}, + {9.0f, 10.0f, 11.0f, 12.0f}, + {13.0f, 14.0f, 15.0f, 16.0f}})); + Array4D filter_data(1, 1, 3, 3); + filter_data.FillWithYX(Array2D( + {{5.0f, 6.0f, 7.0f}, {8.0f, 9.0f, 10.0f}, {11.0f, 12.0f, 13.0f}})); + // clang-format on + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, + error_spec_); + } +}; + +TYPED_TEST_CASE(Convolve_1x1x4x4_1x1x3x3_Same, TestTypes); +TYPED_TEST(Convolve_1x1x4x4_1x1x3x3_Same, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { ComputationBuilder builder(client_, TestName()); @@ -232,36 +274,44 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { error_spec_); } -XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithRHSDilation) { - ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( - input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, - /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2}, - /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); +template +class Convolve1D_1x2x5_1x2x2_WithRHSDilation : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + { + Shape input_shape = MakeShapeWrapper({1, 2, 5}); + Shape filter_shape = MakeShapeWrapper({1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + // Convolution dimensions are bf0_oi0->bo0. + builder.ConvGeneralDilated( + input, filter, /*window_strides=*/{1}, /*padding=*/{{0, 0}}, + /*lhs_dilation=*/{1}, /*rhs_dilation=*/{2}, + /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); + } + + Array3D input( + {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}}); + Array3D filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}}); + + Array3D expected({{{570.0f, 670.0f, 770.0f}}}); + + auto input_literal = + client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR3(&builder, expected, + {input_literal.get(), filter_literal.get()}, + error_spec_); } +}; // namespace - Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); - Array3D filter({{{10, 20}, {30, 40}}}); - - Array3D expected({{{570, 670, 770}}}); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); -} +TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithRHSDilation, TestTypes); +TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithRHSDilation, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSDilation) { ComputationBuilder builder(client_, TestName()); @@ -325,36 +375,45 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithLHSAndRHSDilation) { error_spec_); } -XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_WithPadding) { - ComputationBuilder builder(client_, TestName()); - { - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 2, 5}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - // Convolution dimensions are bf0_oi0->bo0. - builder.ConvGeneralDilated( - input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}}, - /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1}, - /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); +template +class Convolve1D_1x2x5_1x2x2_WithPadding : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + { + Shape input_shape = MakeShapeWrapper({1, 2, 5}); + Shape filter_shape = MakeShapeWrapper({1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + // Convolution dimensions are bf0_oi0->bo0. + builder.ConvGeneralDilated( + input, filter, /*window_strides=*/{1}, /*padding=*/{{2, 2}}, + /*lhs_dilation=*/{1}, /*rhs_dilation=*/{1}, + /*dimension_numbers=*/builder.CreateDefaultConvDimensionNumbers(1)); + } + + Array3D input( + {{{1.0f, 2.0f, 3.0f, 4.0f, 5.0f}, {6.0f, 7.0f, 8.0f, 9.0f, 10.0f}}}); + Array3D filter({{{10.0f, 20.0f}, {30.0f, 40.0f}}}); + + Array3D expected( + {{{0.0f, 260.0f, 510.0f, 610.0f, 710.0f, 810.0f, 350.0f, 0.0f}}}); + + auto input_literal = + client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR3(&builder, expected, + {input_literal.get(), filter_literal.get()}, + error_spec_); } +}; - Array3D input({{{1, 2, 3, 4, 5}, {6, 7, 8, 9, 10}}}); - Array3D filter({{{10, 20}, {30, 40}}}); - - Array3D expected({{{0, 260, 510, 610, 710, 810, 350, 0}}}); - - auto input_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(input)) - .ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*Literal::CreateR3FromArray3D(filter)) - .ConsumeValueOrDie(); - - ComputeAndCompareR3(&builder, expected, - {input_literal.get(), filter_literal.get()}, - error_spec_); -} +TYPED_TEST_CASE(Convolve1D_1x2x5_1x2x2_WithPadding, TestTypes); +TYPED_TEST(Convolve1D_1x2x5_1x2x2_WithPadding, Types) { this->RunTest(); } XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { ComputationBuilder builder(client_, TestName()); @@ -389,12 +448,12 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { } std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); - std::iota(input_elems.begin(), input_elems.end(), 1.0f); + iota(input_elems.begin(), input_elems.end(), 1.0f); auto input_r1 = Literal::CreateR1(input_elems); auto input_r5 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); - std::iota(filter_elems.begin(), filter_elems.end(), 1.0f); + iota(filter_elems.begin(), filter_elems.end(), 1.0f); auto filter_r1 = Literal::CreateR1(filter_elems); auto filter_r5 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); @@ -412,56 +471,73 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { error_spec_); } -XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { - ComputationBuilder builder(client_, TestName()); - std::vector input_dims = {1, 3, 3, 5}; - std::vector filter_dims = {3, 3, 5, 3}; - Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); - Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims); - { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - - // Tensorflow dimension numbers for 2D convolution. - ConvolutionDimensionNumbers dnums; - dnums.set_input_batch_dimension(0); - dnums.set_output_batch_dimension(0); - dnums.add_input_spatial_dimensions(1); - dnums.add_output_spatial_dimensions(1); - dnums.add_input_spatial_dimensions(2); - dnums.add_output_spatial_dimensions(2); - dnums.set_input_feature_dimension(3); - dnums.set_output_feature_dimension(3); - dnums.add_kernel_spatial_dimensions(0); - dnums.add_kernel_spatial_dimensions(1); - dnums.set_kernel_input_feature_dimension(2); - dnums.set_kernel_output_feature_dimension(3); +// std::iota doesn't work when init_value has a type Eigen::half in some build +// servers. The error message is missing the operator ++. +template +void iota_int_init_value(std::vector& values, int init_value) { + std::for_each(values.begin(), values.end(), + [&](T& value) { value = static_cast(init_value++); }); +} - builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, - dnums); +template +class Convolve2D_1x3x3x5_3x3x5x5_Valid : public ConvolutionTest { + public: + void RunTest() { + ComputationBuilder builder(client_, TestName()); + std::vector input_dims = {1, 3, 3, 5}; + std::vector filter_dims = {3, 3, 5, 3}; + Shape input_shape = MakeShapeWrapper(input_dims); + Shape filter_shape = MakeShapeWrapper(filter_dims); + { + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 2D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.add_input_spatial_dimensions(2); + dnums.add_output_spatial_dimensions(2); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); + dnums.add_kernel_spatial_dimensions(0); + dnums.add_kernel_spatial_dimensions(1); + dnums.set_kernel_input_feature_dimension(2); + dnums.set_kernel_output_feature_dimension(3); + + builder.ConvWithGeneralDimensions(input, filter, {1, 1}, Padding::kValid, + dnums); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); + iota_int_init_value(input_elems, 1); + auto input_r1 = Literal::CreateR1(input_elems); + auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); + iota_int_init_value(filter_elems, 1); + auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + + auto expected_r1 = Literal::CreateR1( + {static_cast(92115), static_cast(93150), static_cast(94185)}); + auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); + + ComputeAndCompareLiteral(&builder, *expected_r4, + {input_literal.get(), filter_literal.get()}, + error_spec_); } +}; - std::vector input_elems(ShapeUtil::ElementsIn(input_shape)); - std::iota(input_elems.begin(), input_elems.end(), 1.0f); - auto input_r1 = Literal::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); - - std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape)); - std::iota(filter_elems.begin(), filter_elems.end(), 1.0f); - auto filter_r1 = Literal::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - - auto expected_r1 = Literal::CreateR1({92115, 93150, 94185}); - auto expected_r4 = expected_r1->Reshape({1, 1, 1, 3}).ConsumeValueOrDie(); - - auto input_literal = client_->TransferToServer(*input_r4).ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); - - ComputeAndCompareLiteral(&builder, *expected_r4, - {input_literal.get(), filter_literal.get()}, - error_spec_); -} +TYPED_TEST_CASE(Convolve2D_1x3x3x5_3x3x5x5_Valid, TestTypes); +TYPED_TEST(Convolve2D_1x3x3x5_3x3x5x5_Valid, Types) { this->RunTest(); } // Test fixture to run convolution tests with and without convolution // canonicalization enabled. @@ -519,67 +595,78 @@ struct Convolve1DTestParam { int64 num_windows; }; -class Convolve1D1WindowTest +class Convolve1D1WindowTestBase : public ConvolutionTest, - public ::testing::WithParamInterface {}; - -XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) { - ComputationBuilder builder(client_, TestName()); - int64 input_feature = GetParam().input_feature; - int64 output_feature = GetParam().output_feature; - int64 batch = GetParam().batch; - int64 num_windows = GetParam().num_windows; - int64 window_size = GetParam().window_size; - std::vector input_dims = {batch, window_size + num_windows - 1, - input_feature}; - std::vector filter_dims = {window_size, input_feature, output_feature}; - Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); - Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims); - { - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - - // Tensorflow dimension numbers for 1D convolution. - ConvolutionDimensionNumbers dnums; - dnums.set_input_batch_dimension(0); - dnums.set_output_batch_dimension(0); - dnums.add_input_spatial_dimensions(1); - dnums.add_output_spatial_dimensions(1); - dnums.set_input_feature_dimension(2); - dnums.set_output_feature_dimension(2); - dnums.add_kernel_spatial_dimensions(0); - dnums.set_kernel_input_feature_dimension(1); - dnums.set_kernel_output_feature_dimension(2); - - builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, - dnums); + public ::testing::WithParamInterface { + protected: + template + void TestImpl() { + ComputationBuilder builder(client_, TestName()); + int64 input_feature = GetParam().input_feature; + int64 output_feature = GetParam().output_feature; + int64 batch = GetParam().batch; + int64 num_windows = GetParam().num_windows; + int64 window_size = GetParam().window_size; + std::vector input_dims = {batch, window_size + num_windows - 1, + input_feature}; + std::vector filter_dims = {window_size, input_feature, + output_feature}; + Shape input_shape = MakeShapeWrapper(input_dims); + Shape filter_shape = MakeShapeWrapper(filter_dims); + { + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + + // Tensorflow dimension numbers for 1D convolution. + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.add_output_spatial_dimensions(1); + dnums.set_input_feature_dimension(2); + dnums.set_output_feature_dimension(2); + dnums.add_kernel_spatial_dimensions(0); + dnums.set_kernel_input_feature_dimension(1); + dnums.set_kernel_output_feature_dimension(2); + + builder.ConvWithGeneralDimensions(input, filter, {1}, Padding::kValid, + dnums); + } + + std::vector input_elems(ShapeUtil::ElementsIn(input_shape), + static_cast(1.0f)); + auto input_r1 = Literal::CreateR1(input_elems); + auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + + std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), + static_cast(1.0f)); + + auto filter_r1 = Literal::CreateR1(filter_elems); + auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + + std::vector expect_elems(batch * output_feature * num_windows, + static_cast(window_size * input_feature)); + auto expected_r1 = Literal::CreateR1(expect_elems); + auto expected_r3 = + expected_r1->Reshape({batch, num_windows, output_feature}) + .ConsumeValueOrDie(); + + auto input_literal = + client_->TransferToServer(*input_r3).ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, *expected_r3, + {input_literal.get(), filter_literal.get()}, + error_spec_); } +}; - std::vector input_elems(ShapeUtil::ElementsIn(input_shape), 1.0); - auto input_r1 = Literal::CreateR1(input_elems); - auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); - - std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), 1.0); - - auto filter_r1 = Literal::CreateR1(filter_elems); - auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - - std::vector expect_elems(batch * output_feature * num_windows, - window_size * input_feature); - auto expected_r1 = Literal::CreateR1(expect_elems); - auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature}) - .ConsumeValueOrDie(); +class Convolve1D1WindowTestFloat : public Convolve1D1WindowTestBase {}; - auto input_literal = client_->TransferToServer(*input_r3).ConsumeValueOrDie(); - auto filter_literal = - client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r3, - {input_literal.get(), filter_literal.get()}, - error_spec_); -} +XLA_TEST_P(Convolve1D1WindowTestFloat, Convolve1D1Window) { TestImpl(); } INSTANTIATE_TEST_CASE_P( - Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTest, + Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestFloat, ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2}, Convolve1DTestParam{160, 1, 1, 5, 1}, Convolve1DTestParam{24, 1, 1, 20, 1}, @@ -608,6 +695,48 @@ INSTANTIATE_TEST_CASE_P( ); +#if (XLA_TEST_BACKEND_GPU || XLA_TEST_BACKEND_CPU) +class Convolve1D1WindowTestHalf : public Convolve1D1WindowTestBase {}; + +XLA_TEST_P(Convolve1D1WindowTestHalf, Convolve1D1Window) { + TestImpl(); +} + +INSTANTIATE_TEST_CASE_P( + Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTestHalf, + ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2}, + Convolve1DTestParam{160, 1, 1, 5, 1}, + Convolve1DTestParam{24, 1, 1, 20, 1}, + Convolve1DTestParam{30, 1, 1, 20, 1}, + Convolve1DTestParam{23, 1, 1, 20, 20}, + Convolve1DTestParam{25, 1, 1, 20, 1}, + Convolve1DTestParam{24, 1, 1, 10, 5}, + Convolve1DTestParam{160, 1, 1, 10, 1}, + Convolve1DTestParam{255, 1, 1, 3, 1}, + Convolve1DTestParam{130, 1, 1, 1, 3}, + Convolve1DTestParam{64, 1, 1, 1, 1}, + Convolve1DTestParam{128, 1, 1, 1, 1}, +// TODO(b/72566306): The following five tests failed on CPU with unreasonable +// relative errors. Last ran on 2018-02-22. +#if XLA_TEST_BACKEND_GPU + Convolve1DTestParam{139, 1, 1, 128, 1}, + Convolve1DTestParam{640, 3, 3, 128, 1}, + Convolve1DTestParam{900, 1, 1, 10, 1}, + Convolve1DTestParam{1, 10, 10, 1, 10}, + Convolve1DTestParam{1, 10, 130, 1, 1}, +#endif + Convolve1DTestParam{1, 10, 130, 1, 2}, + Convolve1DTestParam{1, 64, 64, 1, 10}, + Convolve1DTestParam{1, 65, 65, 1, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{128, 128, 128, 128, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{2, 2, 2, 2, 1}, + Convolve1DTestParam{161, 1, 1, 10, 1}) + +); +#endif + TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { ComputationBuilder builder(client_, TestName()); Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 6b0c04c2c083bbfce267dd92d24ef15c06186d26..815962094ae476c4b15713ad2c1e4f1e0d140fd9 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -225,33 +225,39 @@ string PrintDotTestParam( } class ParametricDotTest : public DotOperationTest, - public ::testing::WithParamInterface {}; + public ::testing::WithParamInterface { + protected: + template + void TestImpl(); +}; -XLA_TEST_P(ParametricDotTest, TestF32) { +template +void ParametricDotTest::TestImpl() { DotTestParam param = GetParam(); - std::unique_ptr> dot_lhs_data = - MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); + std::unique_ptr> dot_lhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); std::unique_ptr dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( *dot_lhs_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); std::unique_ptr dot_lhs_handle = client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); - std::unique_ptr> dot_rhs_data = - MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); - std::unique_ptr dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout( - *dot_rhs_data, LayoutUtil::MakeLayout( - MinorToMajorForIsRowMajor(param.dot_rhs_row_major))); + std::unique_ptr> dot_rhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); + Layout rhs_layout = LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_rhs_row_major)); + std::unique_ptr dot_rhs_lit = + Literal::CreateR2FromArray2DWithLayout(*dot_rhs_data, rhs_layout); std::unique_ptr dot_rhs_handle = client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); - std::unique_ptr> addend_data; + std::unique_ptr> addend_data; std::unique_ptr addend_lit; std::unique_ptr addend_handle; if (param.has_addend) { - addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); + addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); addend_lit = Literal::CreateR2FromArray2DWithLayout( *addend_data, LayoutUtil::MakeLayout( MinorToMajorForIsRowMajor(param.addend_row_major))); @@ -259,24 +265,33 @@ XLA_TEST_P(ParametricDotTest, TestF32) { } ComputationBuilder builder(client_, TestName()); - auto prim_type = primitive_util::NativeToPrimitiveType(); + auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}), + builder.Parameter(0, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.k}, + MinorToMajorForIsRowMajor(param.dot_lhs_row_major)), "dot_lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}), + builder.Parameter(1, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.k, param.n}, + MinorToMajorForIsRowMajor(param.dot_rhs_row_major)), "dot_rhs")); if (param.has_addend) { result = builder.Add( - result, - builder.Parameter( - 2, ShapeUtil::MakeShape(prim_type, {param.m, param.n}), "addend")); + result, builder.Parameter( + 2, + ShapeUtil::MakeShapeWithLayout( + prim_type, {param.m, param.n}, + MinorToMajorForIsRowMajor(param.addend_row_major)), + "addend")); } - std::unique_ptr> expected; + std::unique_ptr> expected; if (param.has_addend) { expected = ReferenceUtil::ApplyElementwise2D( - std::plus(), + std::plus(), *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data), *addend_data); } else { @@ -288,9 +303,13 @@ XLA_TEST_P(ParametricDotTest, TestF32) { args.push_back(addend_handle.get()); } - ComputeAndCompareR2(&builder, *expected, args, ErrorSpec(0.3, 3e-3)); + ComputeAndCompareR2(&builder, *expected, args, ErrorSpec(0.3, 3e-3)); } +XLA_TEST_P(ParametricDotTest, TestF32) { TestImpl(); } + +XLA_TEST_P(ParametricDotTest, TestF64) { TestImpl(); } + std::vector CreateDotTestParameters() { std::vector params; @@ -305,30 +324,79 @@ std::vector CreateDotTestParameters() { } }; + add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); + add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); + add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); + + return params; +} + +INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, + ::testing::ValuesIn(CreateDotTestParameters()), + PrintDotTestParam); + +class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest { + public: + ParametricDotTestWithoutLayoutAssignment() { + execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes( + "layout-assignment"); + } +}; + +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF32) { + TestImpl(); +} + +XLA_TEST_P(ParametricDotTestWithoutLayoutAssignment, TestF64) { + TestImpl(); +} + +std::vector CreateNoLayoutAssignmentDotTestParameters() { + std::vector params; + auto add_matrix_vector_dot_test = [&](int k, int n) { - for (bool has_addend : {false, true}) { - params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, - /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, - /*has_addend=*/has_addend, /*addend_row_major=*/true}); - if (n != 1) { - params.push_back( - {/*m=*/n, /*k=*/k, /*n=*/1, - /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, - /*has_addend=*/has_addend, /*addend_row_major=*/true}); + for (bool lhs_row_major : {true, false}) { + for (bool rhs_row_major : {true, false}) { + for (bool has_addend : {true, false}) { + params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/true}); + if (has_addend) { + params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/false}); + } + if (n != 1) { + params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/true}); + if (has_addend) { + params.push_back({/*m=*/n, /*k=*/k, /*n=*/1, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/has_addend, + /*addend_row_major=*/false}); + } + } + } } } }; - add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); - add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); - add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); - add_matrix_vector_dot_test(/*k=*/8, /*n=*/8); add_matrix_vector_dot_test(/*k=*/130, /*n=*/8); add_matrix_vector_dot_test(/*k=*/8, /*n=*/130); add_matrix_vector_dot_test(/*k=*/290, /*n=*/130); add_matrix_vector_dot_test(/*k=*/1, /*n=*/1); add_matrix_vector_dot_test(/*k=*/1, /*n=*/16); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/4); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/3); add_matrix_vector_dot_test(/*k=*/3, /*n=*/16); add_matrix_vector_dot_test(/*k=*/3, /*n=*/3); add_matrix_vector_dot_test(/*k=*/29, /*n=*/29); @@ -339,9 +407,10 @@ std::vector CreateDotTestParameters() { return params; } -INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, - ::testing::ValuesIn(CreateDotTestParameters()), - PrintDotTestParam); +INSTANTIATE_TEST_CASE_P( + DotTests, ParametricDotTestWithoutLayoutAssignment, + ::testing::ValuesIn(CreateNoLayoutAssignmentDotTestParameters()), + PrintDotTestParam); XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { TestSquareMatrixDot(false, false); diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 9f5806c5e16c30cf198027cffab5f78c315cb957..6723c99edb945492abfbac159bed1959d551ec57 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -267,6 +267,28 @@ template reference_preprocessor); } +HloComputation* HloTestBase::FindComputation(HloModule* module, + tensorflow::StringPiece name) { + auto it = c_find_if(module->computations(), + [&](HloComputation* c) { return c->name() == name; }); + if (it == module->computations().end()) { + return nullptr; + } + return *it; +} + +HloInstruction* HloTestBase::FindInstruction(HloModule* module, + tensorflow::StringPiece name) { + for (const HloComputation* c : module->computations()) { + auto it = c_find_if(c->instructions(), + [&](HloInstruction* i) { return i->name() == name; }); + if (it != c->instructions().end()) { + return *it; + } + } + return nullptr; +} + Backend& HloTestBase::backend() { return test_runner_.backend(); } /* static */ diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 4aea9fc9fd027231106e529eb16bcd43f23fbe1c..413bb213fdcb1303f396308d13d9d0b96b47b71f 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -197,6 +197,15 @@ class HloTestBase : public ::testing::Test { ->Clear(); } + // Gets the computation/instruction from the given module with the given name. + // + // This is useful for tests which create HLOs from a string and then want to + // inspect a particular computation or instruction. + HloComputation* FindComputation(HloModule* module, + tensorflow::StringPiece name); + HloInstruction* FindInstruction(HloModule* module, + tensorflow::StringPiece name); + // Return an HLO verifier constructed for the test backend. HloVerifier& verifier() const { return *hlo_verifier_; } diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 5aa71a9261dbd414d1499f15c9b83cd63b634b49..81630df34c58526b6d41492b2b4b3892a02a21c2 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -209,6 +209,11 @@ template <> return CompareFloatsBitwiseEqual(lhs, rhs); } template <> +::testing::AssertionResult CompareEqual(Eigen::half lhs, + Eigen::half rhs) { + return CompareFloatsBitwiseEqual(lhs, rhs); +} +template <> ::testing::AssertionResult CompareEqual(float lhs, float rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 6e6cb7ff1e2ac74dc54f14d8811c9a5d3662bbd2..0a603f4954badd12adf3144320789a5edd0d9c6c 100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -35,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -176,5 +178,38 @@ XLA_TEST_F(MultiOutputFusionTest, 2DFusionSize129) { RunTest2D(true, 129); } XLA_TEST_F(MultiOutputFusionTest, DiffentTypesNoFusion) { RunTest1D(false, 8); } XLA_TEST_F(MultiOutputFusionTest, DiffentTypesFusion) { RunTest1D(true, 8); } +XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { + const char* testcase = R"( + HloModule m + + fused_computation { + x.param_0 = (((s32[]), f32[]), (f32[], s32[])) parameter(0) + gte.3 = ((s32[]), f32[]) get-tuple-element(x.param_0), index=0 + gte.2 = (s32[]) get-tuple-element(gte.3), index=0 + gte.4 = s32[] get-tuple-element(gte.2), index=0 + copy = s32[] copy(gte.4) + ROOT tuple = (s32[]) tuple(copy) + } + + ENTRY thing.v3 { + x = (((s32[]), f32[]), (f32[], s32[])) parameter(0) + ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation + } + )"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::MakeTupleOwned( + Literal::MakeTupleOwned( + Literal::MakeTupleOwned(Literal::CreateR0(42)), + Literal::CreateR0(1.0)), + Literal::MakeTupleOwned(Literal::CreateR0(3.0), + Literal::CreateR0(4))); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 4da6ee91607941b395b00befc98a10e7c17746ed..d7bda77e87f33938162f94dbee42b160906b4087 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -163,7 +163,7 @@ XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) { auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a"); builder.ConvertElementType(a, F32); - int64 value = 3LL << 32; + int64 value = 3LL << 35; std::unique_ptr a_literal = Literal::CreateR0(value); std::unique_ptr a_data = client_->TransferToServer(*a_literal).ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index ac163df127e0087c02777fa3d5ce7970c51b97b9..fe36df160daacc4fdfbdb0b75f8304f91e1a4245 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -237,6 +237,12 @@ INSTANTIATE_TEST_CASE_P( SliceR1TestInstantiation, SliceR1Test, ::testing::Values( +// TODO(b/69425338): This uses too much memory on GPU. +#ifndef XLA_TEST_BACKEND_GPU + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, +#endif R1Spec{10, 0, 0, 1}, R1Spec{10, 7, 7, 1}, R1Spec{10, 0, 5, 1}, @@ -267,13 +273,15 @@ INSTANTIATE_TEST_CASE_P( R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1}, R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1}, R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1}, - R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}, -// TODO(b/69425338): This uses too much memory on GPU. -#ifndef XLA_TEST_BACKEND_GPU - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, - R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, -#endif + R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1} + ), + SliceR1TestDataToString +); + +INSTANTIATE_TEST_CASE_P( + SliceStridedR1TestInstantiation, + SliceR1Test, + ::testing::Values( R1Spec{10, 2, 4, 2}, R1Spec{10, 0, 10, 2}, R1Spec{10, 0, 10, 3}, @@ -285,8 +293,24 @@ INSTANTIATE_TEST_CASE_P( R1Spec{2047, 1024 - 24, 1024 + 160, 31}, R1Spec{2047, 1, 2046, 3 * 128}, R1Spec{4096, 1024 + 3, 4095, 500}, - R1Spec{8192, 0, 8192, 1024 * 3 + 400} - ), + R1Spec{8192, 0, 8192, 1024 * 3 + 400}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 2}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 8}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 7}, + R1Spec{1024 * 1024, 0, 1024 * 1024, 125}, + R1Spec{1024 * 1024, 3, 1024 - 9, 2}, + R1Spec{1024 * 1024, 3, 1024 - 9, 8}, + R1Spec{1024 * 1024, 3, 1024 - 9, 7}, + R1Spec{1024 * 1024, 3, 1024 - 9, 125}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 2}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 8}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 7}, + R1Spec{1024 * 1024, 3, 1024 * 512 - 9, 125}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 2}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 8}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 7}, + R1Spec{1024 * 1024 + 71, 3, 1024 * 512 - 9, 125} + ), SliceR1TestDataToString ); // clang-format on diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index cc4eaf62f50d1fa622c705fab810fe1e1b0fbf08..e2d406f66d94f8ec76faa5b7d2d2e84dcaf6db57 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -161,4 +161,31 @@ string PrependDisabledIfIndicated(const string& test_case_name, #define XLA_TEST_P(test_case_name, test_name) \ XLA_TEST_P_IMPL_(test_case_name, test_name) + +// This is identical to the TEST_F macro from "gtest", but it potentially +// disables the test based on an external manifest file, DISABLED_MANIFEST. +#define XLA_TYPED_TEST(CaseName, TestName) \ + template \ + class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \ + : public CaseName { \ + private: \ + typedef CaseName TestFixture; \ + typedef gtest_TypeParam_ TypeParam; \ + virtual void TestBody(); \ + }; \ + bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \ + ::testing::internal::TypeParameterizedTest< \ + CaseName, \ + ::testing::internal::TemplateSel, \ + GTEST_TYPE_PARAMS_(CaseName)>:: \ + Register( \ + "", ::testing::internal::CodeLocation(__FILE__, __LINE__), \ + #CaseName, \ + ::xla::PrependDisabledIfIndicated(#CaseName, #TestName).c_str(), \ + 0); \ + template \ + void GTEST_TEST_CLASS_NAME_(CaseName, \ + TestName)::TestBody() + #endif // TENSORFLOW_COMPILER_XLA_TESTS_TEST_MACROS_H_ diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index b060fb13b1451aab30cfca73bea0a4a598a9fa3a..0bc7df2a65b44a76f877b6513e6bf93b99fbc1a3 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -287,7 +287,7 @@ StatusOr> MakeFakeLiteral(const Shape& shape) { StatusOr>> MakeFakeArguments( HloModule* const module) { - TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module)); + TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(*module)); const auto params = module->entry_computation()->parameter_instructions(); std::minstd_rand0 engine; std::vector> arguments(params.size()); diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index a8bca70d85ddf168bc441231d6f43bead019b10a..2029312f94a14bc81706368b9ecfc2727fd9fe4c 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -194,8 +194,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } -// TODO(b/68395210): GPU does not tolerate ambiguous top-level buffers. -XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) { +XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { ComputationBuilder b(client_, TestName()); ComputationDataHandle v1, v2; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 89def5d5610cb9522a69297668b443b8c4e03fb5..e60a5a4919f2207939821e787c3c59a08ff3ba4e 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -994,6 +994,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands, *custom_call_target)); break; } + case HloOpcode::kHostCompute: { + optional channel_name; + optional cost_estimate_ns; + attrs["channel_name"] = {/*required=*/true, AttrTy::kString, + &channel_name}; + attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, + &cost_estimate_ns}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateHostCompute( + shape, operands, *channel_name, *cost_estimate_ns)); + break; + } case HloOpcode::kDot: { optional> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { @@ -1035,6 +1049,40 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); break; } + case HloOpcode::kGather: { + optional> output_window_dims; + attrs["output_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &output_window_dims}; + optional> elided_window_dims; + attrs["elided_window_dims"] = { + /*required=*/true, AttrTy::kBracedInt64List, &elided_window_dims}; + optional> gather_dims_to_operand_dims; + attrs["gather_dims_to_operand_dims"] = {/*required=*/true, + AttrTy::kBracedInt64List, + &gather_dims_to_operand_dims}; + optional index_vector_dim; + attrs["index_vector_dim"] = {/*required=*/true, AttrTy::kInt64, + &index_vector_dim}; + optional> window_bounds; + attrs["window_bounds"] = {/*required=*/true, AttrTy::kBracedInt64List, + &window_bounds}; + + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + + GatherDimensionNumbers dim_numbers = HloInstruction::MakeGatherDimNumbers( + /*output_window_dims=*/*output_window_dims, + /*elided_window_dims=*/*elided_window_dims, + /*gather_dims_to_operand_dims=*/*gather_dims_to_operand_dims, + /*index_vector_dim=*/*index_vector_dim); + + instruction = builder->AddInstruction(HloInstruction::CreateGather( + shape, /*operand=*/operands[0], /*gather_indices=*/operands[1], + dim_numbers, *window_bounds)); + break; + } case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index b8c6b59204f897c7dc07b846370b5b776a19a808..863081d654390440aa6506bab4576b3cc5c1cbd1 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -716,6 +716,18 @@ ENTRY %sparse_f32_r1 () -> f32[9] { ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) } +)" +}, +{ +"gather", +R"(HloModule StringifyGather + +ENTRY %Gather (input_tensor: f32[50,49,48,47,46], gather_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] { + %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + %gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} +} + )" }, }); @@ -860,6 +872,18 @@ ENTRY dot { ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0} } +)" +}, +{ +"gather", +R"(HloModule gather + +ENTRY Gather { + input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0) + gather_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1) + ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, gather_indices), output_window_dims={4,5,6,7,8}, elided_window_dims={}, gather_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, window_bounds={30,29,28,27,26} +} + )" }, }); diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 08df5b12b3a53a138f56705531baa3333b23c5d8..e14c8cefa1d16e0a749e7a2c022a24a1c5083b15 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -427,8 +427,9 @@ std::vector> CommonFactors( string SanitizeFileName(string file_name); template -bool c_all_of(Container container, Predicate predicate) { - return std::all_of(std::begin(container), std::end(container), predicate); +bool c_all_of(Container container, Predicate&& predicate) { + return std::all_of(std::begin(container), std::end(container), + std::forward(predicate)); } template +OutputIterator c_copy(InputContainer input_container, + OutputIterator output_iterator) { + return std::copy(std::begin(input_container), std::end(input_container), + output_iterator); +} + +template +void c_sort(InputContainer& input_container) { + std::sort(std::begin(input_container), std::end(input_container)); +} + template -void c_sort(InputContainer& input_container, Comparator comparator) { - std::sort(input_container.begin(), input_container.end(), comparator); +void c_sort(InputContainer& input_container, Comparator&& comparator) { + std::sort(std::begin(input_container), std::end(input_container), + std::forward(comparator)); +} + +template +bool c_binary_search(Sequence& sequence, T&& value) { + return std::binary_search(std::begin(sequence), std::end(sequence), + std::forward(value)); } +template +bool c_is_sorted(const C& c) { + return std::is_sorted(std::begin(c), std::end(c)); +} + +template +auto c_adjacent_find(const C& c) -> decltype(std::begin(c)) { + return std::adjacent_find(std::begin(c), std::end(c)); +} + +template +auto c_find_if(const C& c, Pred&& pred) -> decltype(std::begin(c)) { + return std::find_if(std::begin(c), std::end(c), std::forward(pred)); +} } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 3aea0217539b89b5d60ecfaf2605eee4b69af728..1f16e6d25178fd9c10a30b0c500e090ee2e08117 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -393,6 +393,37 @@ message Window { repeated WindowDimension dimensions = 1; } +// Describes the dimension numbers for a gather operation. +// +// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for +// more details. +message GatherDimensionNumbers { + // "Window indices" is a term for a set of indices that index into the + // interior of a dynamic-slice from the input tensor, the starting indices for + // which were computed from output_gather_dims (see the operation semantic for + // how this is defined) and the gather_indices tensor. + // + // The window indices for a specific output index Out is computed as: + // + // i = 0 + // for (k : [0, input_tensor_shape.rank)) + // window_indices[k] = + // if k in elided_window_dims + // then 0 + // else Out[output_window_dims[i++]] + repeated int64 output_window_dims = 1; + repeated int64 elided_window_dims = 2; + + // This is interpreted as a map from i to gather_dims_to_operand_dims[i]. It + // transforms the gather index looked up from the gather_indices tensor into + // the starting index in the input space. + repeated int64 gather_dims_to_operand_dims = 3; + + // The dimension in the gather_indices input that contains the starting + // indices. + int64 index_vector_dim = 4; +} + // Operation requests that are all collected as a tagged union with a oneof // field in OpRequest. @@ -519,6 +550,20 @@ message CustomCallRequest { Shape shape = 4; } +message HostComputeRequest { + // Operand to the HostCompute. Supports tuple. + repeated ComputationDataHandle operands = 1; + + // Name used to identify HostSend/Recv channels. + string channel_name = 2; + + // Cost estimate in nanoseconds. + int64 cost_estimate_ns = 3; + + // The shape of any data returned by host. + Shape shape = 4; +} + message DotDimensionNumbers { // The dimension numbers that represent the 'lhs' contracting dimensions. repeated int64 lhs_contracting_dimensions = 1; @@ -880,6 +925,13 @@ message RecvRequest { ChannelHandle channel_handle = 2; } +message GatherRequest { + ComputationDataHandle input = 1; + ComputationDataHandle gather_indices = 2; + GatherDimensionNumbers dimension_numbers = 3; + repeated int64 window_bounds = 4; +} + message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal, @@ -957,7 +1009,9 @@ message OpRequest { FftRequest fft_request = 41; ConvertRequest bitcast_convert_request = 42; ConditionalRequest conditional_request = 44; - // Next: 45 + HostComputeRequest host_compute_request = 45; + GatherRequest gather_request = 46; + // Next: 47 } } diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md index b8d73bf24ce60e0b3850d4f39ac9e6d6c2194a02..db37bcf73d144eb81c32a461a276d10be7e2d193 100644 --- a/tensorflow/contrib/android/README.md +++ b/tensorflow/contrib/android/README.md @@ -81,6 +81,11 @@ For documentation on building a self-contained AAR file with cmake, see [tensorflow/contrib/android/cmake](cmake). +### Makefile + +For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../makefile/README.md) + + ## AssetManagerFileSystem This directory also contains a TensorFlow filesystem supporting the Android diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index 74712aeb67c3f0a31def78f25a0298f9c02c9590..270c309ec3f2f8337f2c079decff0b4eeefee234 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -39,7 +39,7 @@ py_library( cuda_py_test( name = "metropolis_hastings_test", - size = "medium", + size = "large", srcs = ["python/kernel_tests/metropolis_hastings_test.py"], additional_deps = [ ":bayesflow_py", @@ -99,6 +99,16 @@ cuda_py_test( ], ) +cuda_py_test( + name = "docstring_util_test", + size = "small", + srcs = ["python/kernel_tests/docstring_util_test.py"], + additional_deps = [ + ":bayesflow_py", + "//tensorflow/python:client_testlib", + ], +) + cuda_py_test( name = "layers_conv_variational_test", size = "small", @@ -200,7 +210,7 @@ cuda_py_test( cuda_py_test( name = "hmc_test", - size = "medium", + size = "large", srcs = ["python/kernel_tests/hmc_test.py"], additional_deps = [ ":bayesflow_py", diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed500b19d8dd72795758a2920119e3680576697 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/docstring_util_test.py @@ -0,0 +1,87 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for docstring utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bayesflow.python.ops import docstring_util +from tensorflow.python.platform import test + + +class DocstringUtil(test.TestCase): + + def _testFunction(self): + doc_args = """x: Input to return as output. + y: Baz.""" + @docstring_util.expand_docstring(args=doc_args) + def foo(x): + # pylint: disable=g-doc-args + """Hello world. + + Args: + @{args} + + Returns: + x. + """ + # pylint: enable=g-doc-args + return x + + true_docstring = """Hello world. + + Args: + x: Input to return as output. + y: Baz. + + Returns: + x. + """ + self.assertEqual(foo.__doc__, true_docstring) + + def _testClassInit(self): + doc_args = """x: Input to return as output. + y: Baz.""" + + class Foo(object): + + @docstring_util.expand_docstring(args=doc_args) + def __init__(self, x, y): + # pylint: disable=g-doc-args + """Hello world. + + Args: + @{args} + + Bar. + """ + # pylint: enable=g-doc-args + pass + + true_docstring = """Hello world. + + Args: + x: Input to return as output. + y: Baz. + + Bar. + """ + self.assertEqual(Foo.__doc__, true_docstring) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py index 5bd834e56245ab4d874544cfd014fe59ae521ea8..819095a060b5f4cf18df6e7e4e4556e50ae44dd3 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -224,12 +224,13 @@ class HMCTest(test.TestCase): expected_exp_x = self._shape_param / self._rate_param - acceptance_probs_, samples_, expected_x_ = sess.run( - [kernel_results.acceptance_probs, samples, expected_x], + log_accept_ratio_, samples_, expected_x_ = sess.run( + [kernel_results.log_accept_ratio, samples, expected_x], feed_dict) actual_x = samples_.mean() actual_exp_x = np.exp(samples_).mean() + acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format( expected_x_, expected_exp_x)) @@ -237,10 +238,10 @@ class HMCTest(test.TestCase): actual_x, actual_exp_x)) self.assertNear(actual_x, expected_x_, 2e-2) self.assertNear(actual_exp_x, expected_exp_x, 2e-2) - self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), - acceptance_probs_ > 0.5) - self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), - acceptance_probs_ <= 1.) + self.assertAllEqual(np.ones_like(acceptance_probs, np.bool), + acceptance_probs > 0.5) + self.assertAllEqual(np.ones_like(acceptance_probs, np.bool), + acceptance_probs <= 1.) def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims): with self.test_session(graph=ops.Graph()) as sess: @@ -265,7 +266,7 @@ class HMCTest(test.TestCase): -x - x**2, # Non-constant gradient. array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) # This log_prob has the property that it is likely to attract - # the HMC flow toward, and below, zero...but for x <=0, + # the flow toward, and below, zero...but for x <=0, # log_prob(x) = -inf, which should result in rejection, as well # as a non-finite log_prob. Thus, this distribution gives us an opportunity # to test out the kernel results ability to correctly capture rejections due @@ -305,11 +306,10 @@ class HMCTest(test.TestCase): self.assertLess(0, neg_inf_mask.sum()) # We better have some rejections due to something other than -inf. self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum()) - # We better have been accepted a decent amount, even near the end of the - # chain, or else this HMC run just got stuck at some point. + # We better have accepted a decent amount, even near end of the chain. self.assertLess( 0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean()) - # We better not have any NaNs in proposed state or log_prob. + # We better not have any NaNs in states or log_prob. # We may have some NaN in grads, which involve multiplication/addition due # to gradient rules. This is the known "NaN grad issue with tf.where." self.assertAllEqual(np.zeros_like(states_), @@ -333,9 +333,11 @@ class HMCTest(test.TestCase): np.testing.assert_array_less(0., pstates_[~neg_inf_mask]) # Acceptance probs are zero whenever proposed state is negative. + acceptance_probs = np.exp(np.minimum( + kernel_results_.log_accept_ratio, 0.)) self.assertAllEqual( np.zeros_like(pstates_[neg_inf_mask]), - kernel_results_.acceptance_probs[neg_inf_mask]) + acceptance_probs[neg_inf_mask]) # The move is accepted ==> state = proposed state. self.assertAllEqual( @@ -383,26 +385,28 @@ class HMCTest(test.TestCase): seed=44) [ - acceptance_probs_, - bad_acceptance_probs_, + log_accept_ratio_, + bad_log_accept_ratio_, initial_draws_, updated_draws_, fake_draws_, ] = sess.run([ - kernel_results.acceptance_probs, - bad_kernel_results.acceptance_probs, + kernel_results.log_accept_ratio, + bad_kernel_results.log_accept_ratio, initial_draws, sample, bad_sample, ], feed_dict) # Confirm step size is small enough that we usually accept. - self.assertGreater(acceptance_probs_.mean(), 0.5) - self.assertGreater(bad_acceptance_probs_.mean(), 0.5) + acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) + bad_acceptance_probs = np.exp(np.minimum(bad_log_accept_ratio_, 0.)) + self.assertGreater(acceptance_probs.mean(), 0.5) + self.assertGreater(bad_acceptance_probs.mean(), 0.5) # Confirm step size is large enough that we sometimes reject. - self.assertLess(acceptance_probs_.mean(), 0.99) - self.assertLess(bad_acceptance_probs_.mean(), 0.99) + self.assertLess(acceptance_probs.mean(), 0.99) + self.assertLess(bad_acceptance_probs.mean(), 0.99) _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(), updated_draws_.flatten()) @@ -410,9 +414,9 @@ class HMCTest(test.TestCase): fake_draws_.flatten()) logging_ops.vlog(1, "acceptance rate for true target: {}".format( - acceptance_probs_.mean())) + acceptance_probs.mean())) logging_ops.vlog(1, "acceptance rate for fake target: {}".format( - bad_acceptance_probs_.mean())) + bad_acceptance_probs.mean())) logging_ops.vlog(1, "K-S p-value for true target: {}".format( ks_p_value_true)) logging_ops.vlog(1, "K-S p-value for fake target: {}".format( @@ -615,15 +619,16 @@ class HMCTest(test.TestCase): step_size=2., num_leapfrog_steps=5, seed=46) - initial_x_, updated_x_, acceptance_probs_ = sess.run( - [initial_x, updated_x, kernel_results.acceptance_probs]) + initial_x_, updated_x_, log_accept_ratio_ = sess.run( + [initial_x, updated_x, kernel_results.log_accept_ratio]) + acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) + logging_ops.vlog(1, "log_accept_ratio = {}".format(log_accept_ratio_)) self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs_, 0.) + self.assertEqual(acceptance_probs, 0.) def testNanFromGradsDontPropagate(self): """Test that update with NaN gradients does not cause NaN in results.""" @@ -638,15 +643,16 @@ class HMCTest(test.TestCase): step_size=2., num_leapfrog_steps=5, seed=47) - initial_x_, updated_x_, acceptance_probs_ = sess.run( - [initial_x, updated_x, kernel_results.acceptance_probs]) + initial_x_, updated_x_, log_accept_ratio_ = sess.run( + [initial_x, updated_x, kernel_results.log_accept_ratio]) + acceptance_probs = np.exp(np.minimum(log_accept_ratio_, 0.)) logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) - logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) + logging_ops.vlog(1, "log_accept_ratio = {}".format(log_accept_ratio_)) self.assertAllEqual(initial_x_, updated_x_) - self.assertEqual(acceptance_probs_, 0.) + self.assertEqual(acceptance_probs, 0.) self.assertAllFinite( gradients_ops.gradients(updated_x, initial_x)[0].eval()) @@ -671,10 +677,10 @@ class HMCTest(test.TestCase): step_size=0.01, num_leapfrog_steps=10, seed=48) - states_, acceptance_probs_ = sess.run( - [states, kernel_results.acceptance_probs]) + states_, log_accept_ratio_ = sess.run( + [states, kernel_results.log_accept_ratio]) self.assertEqual(dtype, states_.dtype) - self.assertEqual(dtype, acceptance_probs_.dtype) + self.assertEqual(dtype, log_accept_ratio_.dtype) def testChainWorksIn64Bit(self): self._testChainWorksDtype(np.float64) diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py index d68fc9081ac78135feab6af6e6318a27ea0a00af..52e36e135d95c1ec919c710f35d59073c2134d05 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py @@ -41,12 +41,14 @@ class _EffectiveSampleSizeTest(object): sess, atol=1e-2, rtol=1e-2, - max_lags_threshold=None, - max_lags=None): + filter_threshold=None, + filter_beyond_lag=None): x = array_ops.placeholder_with_default( input=x_, shape=x_.shape if self.use_static_shape else None) ess = mcmc_diagnostics.effective_sample_size( - x, max_lags_threshold=max_lags_threshold, max_lags=max_lags) + x, + filter_threshold=filter_threshold, + filter_beyond_lag=filter_beyond_lag) if self.use_static_shape: self.assertAllEqual(x.shape[1:], ess.shape) @@ -56,18 +58,19 @@ class _EffectiveSampleSizeTest(object): np.ones_like(ess_) * expected_ess, ess_, atol=atol, rtol=rtol) def testIidRank1NormalHasFullEssMaxLags10(self): - # With a length 5000 iid normal sequence, and max_lags = 10, we should - # have a good estimate of ESS, and it should be close to the full sequence - # length of 5000. - # The choice of max_lags = 10 is a short cutoff, reasonable only since we - # know the correlation length should be zero right away. + # With a length 5000 iid normal sequence, and filter_beyond_lag = 10, we + # should have a good estimate of ESS, and it should be close to the full + # sequence length of 5000. + # The choice of filter_beyond_lag = 10 is a short cutoff, reasonable only + # since we know the correlation length should be zero right away. with self.test_session() as sess: with spectral_ops_test_util.fft_kernel_label_map(): self._check_versus_expected_effective_sample_size( x_=rng.randn(5000).astype(np.float32), expected_ess=5000, sess=sess, - max_lags=10, + filter_beyond_lag=10, + filter_threshold=None, rtol=0.3) def testIidRank2NormalHasFullEssMaxLags10(self): @@ -78,23 +81,25 @@ class _EffectiveSampleSizeTest(object): x_=rng.randn(5000, 2).astype(np.float32), expected_ess=5000, sess=sess, - max_lags=10, + filter_beyond_lag=10, + filter_threshold=None, rtol=0.3) def testIidRank1NormalHasFullEssMaxLagThresholdZero(self): - # With a length 5000 iid normal sequence, and max_lags_threshold = 0, + # With a length 5000 iid normal sequence, and filter_threshold = 0, # we should have a super-duper estimate of ESS, and it should be very close # to the full sequence length of 5000. - # The choice of max_lags_cutoff = 0 means we cutoff as soon as the auto-corr - # is below zero. This should happen very quickly, due to the fact that the - # theoretical auto-corr is [1, 0, 0,...] + # The choice of filter_beyond_lag = 0 means we cutoff as soon as the + # auto-corris below zero. This should happen very quickly, due to the fact + # that the theoretical auto-corr is [1, 0, 0,...] with self.test_session() as sess: with spectral_ops_test_util.fft_kernel_label_map(): self._check_versus_expected_effective_sample_size( x_=rng.randn(5000).astype(np.float32), expected_ess=5000, sess=sess, - max_lags_threshold=0., + filter_beyond_lag=None, + filter_threshold=0., rtol=0.1) def testIidRank2NormalHasFullEssMaxLagThresholdZero(self): @@ -105,7 +110,8 @@ class _EffectiveSampleSizeTest(object): x_=rng.randn(5000, 2).astype(np.float32), expected_ess=5000, sess=sess, - max_lags_threshold=0., + filter_beyond_lag=None, + filter_threshold=0., rtol=0.1) def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLags50(self): @@ -121,7 +127,8 @@ class _EffectiveSampleSizeTest(object): x_=x_, expected_ess=50000 // 10, sess=sess, - max_lags=50, + filter_beyond_lag=50, + filter_threshold=None, rtol=0.2) def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLagsThresholdZero( @@ -138,7 +145,8 @@ class _EffectiveSampleSizeTest(object): x_=x_, expected_ess=50000 // 10, sess=sess, - max_lags_threshold=0., + filter_beyond_lag=None, + filter_threshold=0., rtol=0.1) def testListArgs(self): @@ -148,16 +156,16 @@ class _EffectiveSampleSizeTest(object): x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) y_ = rng.randn(50000).astype(np.float32) states = [x_, x_, y_, y_] - max_lags_threshold = [0., None, 0., None] - max_lags = [None, 5, None, 5] + filter_threshold = [0., None, 0., None] + filter_beyond_lag = [None, 5, None, 5] # See other tests for reasoning on tolerance. with self.test_session() as sess: with spectral_ops_test_util.fft_kernel_label_map(): ess = mcmc_diagnostics.effective_sample_size( states, - max_lags_threshold=max_lags_threshold, - max_lags=max_lags) + filter_threshold=filter_threshold, + filter_beyond_lag=filter_beyond_lag) ess_ = sess.run(ess) self.assertAllEqual(4, len(ess_)) @@ -166,6 +174,59 @@ class _EffectiveSampleSizeTest(object): self.assertAllClose(50000, ess_[2], rtol=0.1) self.assertAllClose(50000, ess_[3], rtol=0.1) + def testMaxLagsThresholdLessThanNeg1SameAsNone(self): + # Setting both means we filter out items R_k from the auto-correlation + # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. + + # x_ has correlation length 10. + iid_x_ = rng.randn(500, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + x = array_ops.placeholder_with_default( + input=x_, shape=x_.shape if self.use_static_shape else None) + + ess_none_none = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=None, filter_beyond_lag=None) + ess_none_200 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=None, filter_beyond_lag=200) + ess_neg2_200 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=-2., filter_beyond_lag=200) + ess_neg2_none = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=-2., filter_beyond_lag=None) + ess_none_none_, ess_none_200_, ess_neg2_200_, ess_neg2_none_ = sess.run( + [ess_none_none, ess_none_200, ess_neg2_200, ess_neg2_none]) + + # filter_threshold=-2 <==> filter_threshold=None. + self.assertAllClose(ess_none_none_, ess_neg2_none_) + self.assertAllClose(ess_none_200_, ess_neg2_200_) + + def testMaxLagsArgsAddInAnOrManner(self): + # Setting both means we filter out items R_k from the auto-correlation + # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. + + # x_ has correlation length 10. + iid_x_ = rng.randn(500, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + x = array_ops.placeholder_with_default( + input=x_, shape=x_.shape if self.use_static_shape else None) + + ess_1_9 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=1., filter_beyond_lag=9) + ess_1_none = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=1., filter_beyond_lag=None) + ess_none_9 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=1., filter_beyond_lag=9) + ess_1_9_, ess_1_none_, ess_none_9_ = sess.run( + [ess_1_9, ess_1_none, ess_none_9]) + + # Since R_k = 1 for k < 10, and R_k < 1 for k >= 10, + # filter_threshold = 1 <==> filter_beyond_lag = 9. + self.assertAllClose(ess_1_9_, ess_1_none_) + self.assertAllClose(ess_1_9_, ess_none_9_) + class EffectiveSampleSizeStaticTest(test.TestCase, _EffectiveSampleSizeTest): diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py index 63d93fad64d077aa385b72428665e841b6784b90..f508e5b114a55fc1aeb07212595fda45fc308c7b 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py @@ -12,34 +12,195 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for metropolis_hastings.py.""" +"""Tests for Metropolis-Hastings.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np + from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings_impl as mh +from tensorflow.contrib.distributions.python.ops import mvn_tril as mvn_tril_lib +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test -class McmcStepTest(test.TestCase): +class MetropolisHastingsTest(test.TestCase): + + def testKernelStateTensor(self): + """Test that transition kernel works with tensor input to `state`.""" + loc = variable_scope.get_variable("loc", initializer=0.) + + def target_log_prob_fn(loc): + return normal_lib.Normal(loc=0.0, scale=0.1).log_prob(loc) + + new_state, _ = mh.kernel( + target_log_prob_fn=target_log_prob_fn, + proposal_fn=mh.proposal_normal(scale=0.05), + current_state=loc, + seed=231251) + loc_update = loc.assign(new_state) + + init = variables.initialize_all_variables() + with self.test_session() as sess: + sess.run(init) + loc_samples = [] + for _ in range(2500): + loc_sample = sess.run(loc_update) + loc_samples.append(loc_sample) + loc_samples = loc_samples[500:] # drop samples for burn-in + + self.assertAllClose(np.mean(loc_samples), 0.0, rtol=1e-5, atol=1e-1) + self.assertAllClose(np.std(loc_samples), 0.1, rtol=1e-5, atol=1e-1) + + def testKernelStateList(self): + """Test that transition kernel works with list input to `state`.""" + num_chains = 2 + loc_one = variable_scope.get_variable( + "loc_one", [num_chains], + initializer=init_ops.zeros_initializer()) + loc_two = variable_scope.get_variable( + "loc_two", [num_chains], initializer=init_ops.zeros_initializer()) + + def target_log_prob_fn(loc_one, loc_two): + loc = array_ops.stack([loc_one, loc_two]) + log_prob = mvn_tril_lib.MultivariateNormalTriL( + loc=constant_op.constant([0., 0.]), + scale_tril=constant_op.constant([[0.1, 0.1], [0.0, 0.1]])).log_prob( + loc) + return math_ops.reduce_sum(log_prob, 0) + + def proposal_fn(loc_one, loc_two): + loc_one_proposal = mh.proposal_normal(scale=0.05) + loc_two_proposal = mh.proposal_normal(scale=0.05) + loc_one_sample, _ = loc_one_proposal(loc_one) + loc_two_sample, _ = loc_two_proposal(loc_two) + return [loc_one_sample, loc_two_sample], None + + new_state, _ = mh.kernel( + target_log_prob_fn=target_log_prob_fn, + proposal_fn=proposal_fn, + current_state=[loc_one, loc_two], + seed=12415) + loc_one_update = loc_one.assign(new_state[0]) + loc_two_update = loc_two.assign(new_state[1]) + + init = variables.initialize_all_variables() + with self.test_session() as sess: + sess.run(init) + loc_one_samples = [] + loc_two_samples = [] + for _ in range(10000): + loc_one_sample, loc_two_sample = sess.run( + [loc_one_update, loc_two_update]) + loc_one_samples.append(loc_one_sample) + loc_two_samples.append(loc_two_sample) + + loc_one_samples = np.array(loc_one_samples) + loc_two_samples = np.array(loc_two_samples) + loc_one_samples = loc_one_samples[1000:] # drop samples for burn-in + loc_two_samples = loc_two_samples[1000:] # drop samples for burn-in + + self.assertAllClose(np.mean(loc_one_samples, 0), + np.array([0.] * num_chains), + rtol=1e-5, atol=1e-1) + self.assertAllClose(np.mean(loc_two_samples, 0), + np.array([0.] * num_chains), + rtol=1e-5, atol=1e-1) + self.assertAllClose(np.std(loc_one_samples, 0), + np.array([0.1] * num_chains), + rtol=1e-5, atol=1e-1) + self.assertAllClose(np.std(loc_two_samples, 0), + np.array([0.1] * num_chains), + rtol=1e-5, atol=1e-1) + + def testKernelResultsUsingTruncatedDistribution(self): + def log_prob(x): + return array_ops.where( + x >= 0., + -x - x**2, + array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) + # The truncated distribution has the property that it is likely to attract + # the flow toward, and below, zero...but for x <=0, + # log_prob(x) = -inf, which should result in rejection, as well + # as a non-finite log_prob. Thus, this distribution gives us an opportunity + # to test out the kernel results ability to correctly capture rejections due + # to finite AND non-finite reasons. + + num_results = 1000 + # Large step size, will give rejections due to going into a region of + # log_prob = -inf. + step_size = 0.3 + num_chains = 2 + + with self.test_session(graph=ops.Graph()) as sess: + + # Start multiple independent chains. + initial_state = ops.convert_to_tensor([0.1] * num_chains) - def test_density_increasing_step_accepted(self): + states = [] + is_accepted = [] + proposed_states = [] + current_state = initial_state + for _ in range(num_results): + current_state, kernel_results = mh.kernel( + target_log_prob_fn=log_prob, + proposal_fn=mh.proposal_uniform(step_size=step_size), + current_state=current_state, + seed=42) + states.append(current_state) + proposed_states.append(kernel_results.proposed_state) + is_accepted.append(kernel_results.is_accepted) + + states = array_ops.stack(states) + proposed_states = array_ops.stack(proposed_states) + is_accepted = array_ops.stack(is_accepted) + states_, pstates_, is_accepted_ = sess.run( + [states, proposed_states, is_accepted]) + + # We better have accepted a decent amount, even near end of the chain. + self.assertLess( + 0.1, is_accepted_[int(0.9 * num_results):].mean()) + # We better not have any NaNs in states. + self.assertAllEqual(np.zeros_like(states_), + np.isnan(states_)) + # We better not have any +inf in states. + self.assertAllEqual(np.zeros_like(states_), + np.isposinf(states_)) + + # The move is accepted ==> state = proposed state. + self.assertAllEqual( + states_[is_accepted_], + pstates_[is_accepted_], + ) + + # The move was rejected <==> state[t] == state[t - 1]. + for t in range(1, num_results): + for i in range(num_chains): + if is_accepted_[t, i]: + self.assertNotEqual(states_[t, i], states_[t - 1, i]) + else: + self.assertEqual(states_[t, i], states_[t - 1, i]) + + def testDensityIncreasingStepAccepted(self): """Tests that if a transition increases density, it is always accepted.""" target_log_density = lambda x: - x * x - state = variable_scope.get_variable('state', initializer=10.) + state = variable_scope.get_variable("state", initializer=10.) state_log_density = variable_scope.get_variable( - 'state_log_density', + "state_log_density", initializer=target_log_density(state.initialized_value())) log_accept_ratio = variable_scope.get_variable( - 'log_accept_ratio', initializer=0.) + "log_accept_ratio", initializer=0.) get_next_proposal = lambda x: (x - 1., None) step = mh.evolve(state, state_log_density, log_accept_ratio, @@ -54,7 +215,7 @@ class McmcStepTest(test.TestCase): self.assertAlmostEqual(sample, 9 - j) self.assertAlmostEqual(sample_log_density, - (9 - j) * (9 - j)) - def test_sample_properties(self): + def testSampleProperties(self): """Tests that the samples converge to the target distribution.""" def target_log_density(x): @@ -62,16 +223,16 @@ class McmcStepTest(test.TestCase): return - (x - 2.0) * (x - 2.0) * 0.5 # Use the uniform random walker to generate proposals. - proposal_fn = mh.uniform_random_proposal( + proposal_fn = mh.proposal_uniform( step_size=1.0, seed=1234) - state = variable_scope.get_variable('state', initializer=0.0) + state = variable_scope.get_variable("state", initializer=0.0) state_log_density = variable_scope.get_variable( - 'state_log_density', + "state_log_density", initializer=target_log_density(state.initialized_value())) - log_accept_ratio = variable_scope.get_variable( - 'log_accept_ratio', initializer=0.) + "log_accept_ratio", initializer=0.) + # Random walk MCMC converges slowly so need to put in enough iterations. num_iterations = 5000 step = mh.evolve(state, state_log_density, log_accept_ratio, @@ -98,11 +259,11 @@ class McmcStepTest(test.TestCase): self.assertAlmostEqual(sample_mean, 2.0, delta=0.1) self.assertAlmostEqual(sample_variance, 1.0, delta=0.1) - def test_normal_proposals(self): + def testProposalNormal(self): """Tests that the normal proposals are correctly distributed.""" initial_points = array_ops.ones([10000], dtype=dtypes.float32) - proposal_fn = mh.normal_random_proposal( + proposal_fn = mh.proposal_normal( scale=2.0, seed=1234) proposal_points, _ = proposal_fn(initial_points) @@ -115,7 +276,7 @@ class McmcStepTest(test.TestCase): self.assertAlmostEqual(np.mean(sample), 1.0, delta=0.1) self.assertAlmostEqual(np.std(sample), 2.0, delta=0.1) - def test_docstring_example(self): + def testDocstringExample(self): """Tests the simplified docstring example with multiple chains.""" n = 2 # dimension of the problem @@ -123,7 +284,7 @@ class McmcStepTest(test.TestCase): # Generate 300 initial values randomly. Each of these would be an # independent starting point for a Markov chain. state = variable_scope.get_variable( - 'state', initializer=random_ops.random_normal( + "state", initializer=random_ops.random_normal( [300, n], mean=3.0, dtype=dtypes.float32, seed=42)) # Computes the log(p(x)) for the unit normal density and ignores the @@ -133,12 +294,12 @@ class McmcStepTest(test.TestCase): # Initial log-density value state_log_density = variable_scope.get_variable( - 'state_log_density', + "state_log_density", initializer=log_density(state.initialized_value())) # A variable to store the log_acceptance_ratio: log_acceptance_ratio = variable_scope.get_variable( - 'log_acceptance_ratio', + "log_acceptance_ratio", initializer=array_ops.zeros([300], dtype=dtypes.float32)) # Generates random proposals by moving each coordinate uniformly and @@ -175,5 +336,5 @@ class McmcStepTest(test.TestCase): - np.reshape(covariance, [n**2]))), 0, delta=0.2) -if __name__ == '__main__': +if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/docstring_util.py b/tensorflow/contrib/bayesflow/python/ops/docstring_util.py new file mode 100644 index 0000000000000000000000000000000000000000..081f2d5a8bfd437fd173f63b4226fb7df6ca921c --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/docstring_util.py @@ -0,0 +1,88 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for programmable docstrings. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import six + + +def expand_docstring(**kwargs): + """Decorator to programmatically expand the docstring. + + Args: + **kwargs: Keyword arguments to set. For each key-value pair `k` and `v`, + the key is found as `@{k}` in the docstring and replaced with `v`. + + Returns: + Decorated function. + """ + def _fn_wrapped(fn): + """Original function with modified `__doc__` attribute.""" + doc = _trim(fn.__doc__) + for k, v in six.iteritems(kwargs): + # Capture each @{k} reference to replace with v. + # We wrap the replacement in a function so no backslash escapes + # are processed. + pattern = r'@\{' + str(k) + r'\}' + doc = re.sub(pattern, lambda match: v, doc) # pylint: disable=cell-var-from-loop + fn.__doc__ = doc + return fn + return _fn_wrapped + + +def _trim(docstring): + """Trims docstring indentation. + + In general, multi-line docstrings carry their level of indentation when + defined under a function or class method. This function standardizes + indentation levels by removing them. Taken from PEP 257 docs. + + Args: + docstring: Python string to trim indentation. + + Returns: + Trimmed docstring. + """ + if not docstring: + return '' + # Convert tabs to spaces (following the normal Python rules) + # and split into a list of lines: + lines = docstring.expandtabs().splitlines() + # Determine minimum indentation (first line doesn't count): + indent = None + for line in lines[1:]: + stripped = line.lstrip() + if stripped: + if indent is None: + indent = len(line) - len(stripped) + else: + indent = min(indent, len(line) - len(stripped)) + # Remove indentation (first line is special): + trimmed = [lines[0].strip()] + if indent is not None: + for line in lines[1:]: + trimmed.append(line[indent:].rstrip()) + # Strip off trailing and leading blank lines: + while trimmed and not trimmed[-1]: + trimmed.pop() + while trimmed and not trimmed[0]: + trimmed.pop(0) + # Return a single string: + return '\n'.join(trimmed) diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py index f724910c59315867a42a56fab3deb36f5d3adb7a..82693c2b7bcdbca9f6f4a1d799be5728bb5d36bf 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -46,15 +46,13 @@ __all__ = [ KernelResults = collections.namedtuple( "KernelResults", [ - "acceptance_probs", + "log_accept_ratio", "current_grads_target_log_prob", # "Current result" means "accepted". "current_target_log_prob", # "Current result" means "accepted". - "energy_change", "is_accepted", "proposed_grads_target_log_prob", "proposed_state", "proposed_target_log_prob", - "random_positive", ]) @@ -63,15 +61,13 @@ def _make_dummy_kernel_results( dummy_target_log_prob, dummy_grads_target_log_prob): return KernelResults( - acceptance_probs=dummy_target_log_prob, + log_accept_ratio=dummy_target_log_prob, current_grads_target_log_prob=dummy_grads_target_log_prob, current_target_log_prob=dummy_target_log_prob, - energy_change=dummy_target_log_prob, is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool), proposed_grads_target_log_prob=dummy_grads_target_log_prob, proposed_state=dummy_state, proposed_target_log_prob=dummy_target_log_prob, - random_positive=dummy_target_log_prob, ) @@ -109,10 +105,13 @@ def sample_chain( Note: `target_log_prob_fn` is called exactly twice. - Only one out of every `num_steps_between_samples + 1` steps is included in the - returned results. This "thinning" comes at a cost of reduced statistical - power, while reducing memory requirements and autocorrelation. For more - discussion see [1]. + Since HMC states are correlated, it is sometimes desirable to produce + additional intermediate states, and then discard them, ending up with a set of + states with decreased autocorrelation. See [1]. Such "thinning" is made + possible by setting `num_steps_between_results > 0`. The chain then takes + `num_steps_between_results` extra steps between the steps that make it into + the results. The extra steps are never materialized (in calls to `sess.run`), + and thus do not increase memory requirements. [1]: "Statistically efficient thinning of a Markov chain sampler." Art B. Owen. April 2017. @@ -225,10 +224,8 @@ def sample_chain( Default value: 0 (i.e., no burn-in). num_steps_between_results: Integer number of chain steps between collecting a result. Only one out of every `num_steps_between_samples + 1` steps is - included in the returned results. This "thinning" comes at a cost of - reduced statistical power, while reducing memory requirements and - autocorrelation. For more discussion see [1]. - Default value: 0 (i.e., no subsampling). + included in the returned results. The number of returned chain states is + still equal to `num_results`. Default value: 0 (i.e., no thinning). seed: Python integer to seed the random number generator. current_target_log_prob: (Optional) `Tensor` representing the value of `target_log_prob_fn` at the `current_state`. The only reason to specify @@ -243,7 +240,7 @@ def sample_chain( Default value: `None` (i.e., "hmc_sample_chain"). Returns: - accepted_states: Tensor or Python list of `Tensor`s representing the + next_states: Tensor or Python list of `Tensor`s representing the state(s) of the Markov chain(s) at each result step. Has same shape as input `current_state` but with a prepended `num_results`-size dimension. kernel_results: `collections.namedtuple` of internal calculations used to @@ -469,7 +466,7 @@ def sample_annealed_importance_chain( Default value: `None` (i.e., "hmc_sample_annealed_importance_chain"). Returns: - accepted_state: `Tensor` or Python list of `Tensor`s representing the + next_state: `Tensor` or Python list of `Tensor`s representing the state(s) of the Markov chain(s) at the final iteration. Has same shape as input `current_state`. ais_weights: Tensor with the estimated weight(s). Has shape matching @@ -590,18 +587,19 @@ def kernel(target_log_prob_fn, target = tfd.Normal(loc=dtype(0), scale=dtype(1)) - new_x, other_results = hmc.kernel( + next_x, other_results = hmc.kernel( target_log_prob_fn=target.log_prob, current_state=x, step_size=step_size, num_leapfrog_steps=3)[:4] - x_update = x.assign(new_x) + x_update = x.assign(next_x) step_size_update = step_size.assign_add( step_size * tf.where( - other_results.acceptance_probs > target_accept_rate, - 0.01, -0.01)) + tf.exp(tf.minimum(other_results.log_accept_ratio), 0.) > + target_accept_rate, + 0.01, -0.01)) warmup = tf.group([x_update, step_size_update]) @@ -752,7 +750,7 @@ def kernel(target_log_prob_fn, Default value: `None` (i.e., "hmc_kernel"). Returns: - accepted_state: Tensor or Python list of `Tensor`s representing the state(s) + next_state: Tensor or Python list of `Tensor`s representing the state(s) of the Markov chain(s) at each result step. Has same shape as `current_state`. kernel_results: `collections.namedtuple` of internal calculations used to @@ -805,30 +803,27 @@ def kernel(target_log_prob_fn, proposed_target_log_prob, proposed_momentums, independent_chain_ndims) + log_accept_ratio = -energy_change - # u < exp(min(-energy, 0)), where u~Uniform[0,1) - # ==> -log(u) >= max(e, 0) - # ==> -log(u) >= e - # (Perhaps surprisingly, we don't have a better way to obtain a random - # uniform from positive reals, i.e., `tf.random_uniform(minval=0, - # maxval=np.inf)` won't work.) - random_uniform = random_ops.random_uniform( + # u < exp(log_accept_ratio), where u~Uniform[0,1) + # ==> log(u) < log_accept_ratio + random_value = random_ops.random_uniform( shape=array_ops.shape(energy_change), dtype=energy_change.dtype, seed=seed) - random_positive = -math_ops.log(random_uniform) - is_accepted = random_positive >= energy_change + random_negative = math_ops.log(random_value) + is_accepted = random_negative < log_accept_ratio accepted_target_log_prob = array_ops.where(is_accepted, proposed_target_log_prob, current_target_log_prob) - accepted_state_parts = [_choose(is_accepted, - proposed_state_part, - current_state_part, - independent_chain_ndims) - for current_state_part, proposed_state_part - in zip(current_state_parts, proposed_state_parts)] + next_state_parts = [_choose(is_accepted, + proposed_state_part, + current_state_part, + independent_chain_ndims) + for current_state_part, proposed_state_part + in zip(current_state_parts, proposed_state_parts)] accepted_grads_target_log_prob = [ _choose(is_accepted, @@ -840,17 +835,15 @@ def kernel(target_log_prob_fn, maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] return [ - maybe_flatten(accepted_state_parts), + maybe_flatten(next_state_parts), KernelResults( - acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)), + log_accept_ratio=log_accept_ratio, current_grads_target_log_prob=accepted_grads_target_log_prob, current_target_log_prob=accepted_target_log_prob, - energy_change=energy_change, is_accepted=is_accepted, proposed_grads_target_log_prob=proposed_grads_target_log_prob, proposed_state=maybe_flatten(proposed_state_parts), proposed_target_log_prob=proposed_target_log_prob, - random_positive=random_positive, ), ] @@ -882,8 +875,8 @@ def _leapfrog_integrator(current_momentums, momentum = tf.placeholder(np.float32) [ - new_momentums, - new_positions, + next_momentums, + next_positions, ] = hmc._leapfrog_integrator( current_momentums=[momentum], target_log_prob_fn=tfd.MultivariateNormalDiag( @@ -900,7 +893,7 @@ def _leapfrog_integrator(current_momentums, positions = np.zeros([num_iter, dims], dtype) for i in xrange(num_iter): position_, momentum_ = sess.run( - [new_momentums[0], new_position[0]], + [next_momentums[0], next_position[0]], feed_dict={position: position_, momentum: momentum_}) positions[i] = position_ @@ -943,9 +936,9 @@ def _leapfrog_integrator(current_momentums, state(s) of the Markov chain(s) at each result step. Has same shape as input `current_state_parts`. proposed_target_log_prob: `Tensor` representing the value of - `target_log_prob_fn` at `accepted_state`. + `target_log_prob_fn` at `next_state`. proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt - `accepted_state`. + `next_state`. Raises: ValueError: if `len(momentums) != len(state_parts)`. @@ -1065,8 +1058,8 @@ def _compute_energy_change(current_target_log_prob, axis=-1) lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1), axis=-1) - lp0 = -current_target_log_prob # log_potential - lp1 = -proposed_target_log_prob # proposed_log_potential + lp0 = -current_target_log_prob # potential + lp1 = -proposed_target_log_prob # proposed_potential x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)], axis=-1) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py index 7723cfb442712626ff415f1412e3362f2392ce9f..cb80718f719ff31fb8ba5066170342fc69630780 100644 --- a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py +++ b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.bayesflow.python.ops import docstring_util from tensorflow.contrib.bayesflow.python.ops import layers_util from tensorflow.contrib.distributions.python.ops import independent as independent_lib from tensorflow.python.framework import dtypes @@ -34,6 +35,45 @@ from tensorflow.python.ops.distributions import kullback_leibler as kl_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.distributions import util as distribution_util +doc_args = """activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer.""" + class _ConvVariational(layers_lib.Layer): """Abstract nD convolution layer (private, used as implementation base). @@ -55,65 +95,6 @@ class _ConvVariational(layers_lib.Layer): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: - rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, ..., channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: A string, the name of the layer. - Properties: rank: Python integer, dimensionality of convolution. filters: Python integer, dimensionality of the output space. @@ -134,6 +115,7 @@ class _ConvVariational(layers_lib.Layer): bias_divergence_fn: `callable` returning divergence. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, rank, @@ -157,6 +139,33 @@ class _ConvVariational(layers_lib.Layer): bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + rank: An integer, the rank of the convolution, e.g. "2" for 2D + convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, ..., + channels)` while `channels_first` corresponds to inputs with shape + `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + @{args} + """ + # pylint: enable=g-doc-args super(_ConvVariational, self).__init__( trainable=trainable, name=name, @@ -371,65 +380,6 @@ class _ConvReparameterization(_ConvVariational): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: - rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, ..., channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: A string, the name of the layer. - Properties: rank: Python integer, dimensionality of convolution. filters: Python integer, dimensionality of the output space. @@ -454,6 +404,7 @@ class _ConvReparameterization(_ConvVariational): International Conference on Learning Representations, 2014. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, rank, @@ -477,6 +428,33 @@ class _ConvReparameterization(_ConvVariational): bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + rank: An integer, the rank of the convolution, e.g. "2" for 2D + convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, ..., + channels)` while `channels_first` corresponds to inputs with shape + `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + @{args} + """ + # pylint: enable=g-doc-args super(_ConvReparameterization, self).__init__( rank=rank, filters=filters, @@ -529,63 +507,6 @@ class Conv1DReparameterization(_ConvReparameterization): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - Properties: filters: Python integer, dimensionality of the output space. kernel_size: Size of the convolution window. @@ -639,6 +560,7 @@ class Conv1DReparameterization(_ConvReparameterization): International Conference on Learning Representations, 2014. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, filters, @@ -661,6 +583,31 @@ class Conv1DReparameterization(_ConvReparameterization): bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, length, + channels)` while `channels_first` corresponds to inputs with shape + `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + @{args} + """ + # pylint: enable=g-doc-args super(Conv1DReparameterization, self).__init__( rank=1, filters=filters, @@ -683,6 +630,7 @@ class Conv1DReparameterization(_ConvReparameterization): name=name, **kwargs) +@docstring_util.expand_docstring(args=doc_args) def conv1d_reparameterization( inputs, filters, @@ -705,6 +653,7 @@ def conv1d_reparameterization( bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, reuse=None): + # pylint: disable=g-doc-args """Functional interface for 1D convolution layer (e.g. temporal convolution). This layer creates a convolution kernel that is convolved @@ -726,7 +675,7 @@ def conv1d_reparameterization( (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: + Args: inputs: Tensor input. filters: Integer, the dimensionality of the output space (i.e. the number of filters in the convolution). @@ -746,43 +695,7 @@ def conv1d_reparameterization( the dilation rate to use for dilated convolution. Currently, specifying any `dilation_rate` value != 1 is incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. + @{args} reuse: Boolean, whether to reuse the weights of a previous layer by the same name. @@ -827,6 +740,7 @@ def conv1d_reparameterization( Diederik P. Kingma, Max Welling. International Conference on Learning Representations, 2014. """ + # pylint: enable=g-doc-args layer = Conv1DReparameterization( filters=filters, kernel_size=kernel_size, @@ -874,70 +788,6 @@ class Conv2DReparameterization(_ConvReparameterization): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - Properties: filters: Python integer, dimensionality of the output space. kernel_size: Size of the convolution window. @@ -994,6 +844,7 @@ class Conv2DReparameterization(_ConvReparameterization): International Conference on Learning Representations, 2014. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, filters, @@ -1016,6 +867,37 @@ class Conv2DReparameterization(_ConvReparameterization): bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, height, + width, channels)` while `channels_first` corresponds to inputs with + shape `(batch, channels, height, width)`. + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + @{args} + """ + # pylint: enable=g-doc-args super(Conv2DReparameterization, self).__init__( rank=2, filters=filters, @@ -1038,6 +920,7 @@ class Conv2DReparameterization(_ConvReparameterization): name=name, **kwargs) +@docstring_util.expand_docstring(args=doc_args) def conv2d_reparameterization( inputs, filters, @@ -1060,6 +943,7 @@ def conv2d_reparameterization( bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, reuse=None): + # pylint: disable=g-doc-args """Functional interface for the 2D convolution layer. This layer creates a convolution kernel that is convolved @@ -1081,7 +965,7 @@ def conv2d_reparameterization( (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: + Args: inputs: Tensor input. filters: Integer, the dimensionality of the output space (i.e. the number of filters in the convolution). @@ -1101,50 +985,13 @@ def conv2d_reparameterization( `channels_last` corresponds to inputs with shape `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape `(batch, channels, height, width)`. - dilation_rate: An integer or tuple/list of 2 integers, specifying the dilation rate to use for dilated convolution. Can be a single integer to specify the same value for all spatial dimensions. Currently, specifying any `dilation_rate` value != 1 is incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. + @{args} reuse: Boolean, whether to reuse the weights of a previous layer by the same name. @@ -1193,6 +1040,7 @@ def conv2d_reparameterization( Diederik P. Kingma, Max Welling. International Conference on Learning Representations, 2014. """ + # pylint: enable=g-doc-args layer = Conv2DReparameterization( filters=filters, kernel_size=kernel_size, @@ -1240,71 +1088,6 @@ class Conv3DReparameterization(_ConvReparameterization): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. - Properties: filters: Python integer, dimensionality of the output space. kernel_size: Size of the convolution window. @@ -1361,6 +1144,7 @@ class Conv3DReparameterization(_ConvReparameterization): International Conference on Learning Representations, 2014. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, filters, @@ -1383,6 +1167,38 @@ class Conv3DReparameterization(_ConvReparameterization): bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, depth, + height, width, channels)` while `channels_first` corresponds to inputs + with shape `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + @{args} + """ + # pylint: enable=g-doc-args super(Conv3DReparameterization, self).__init__( rank=3, filters=filters, @@ -1405,6 +1221,7 @@ class Conv3DReparameterization(_ConvReparameterization): name=name, **kwargs) +@docstring_util.expand_docstring(args=doc_args) def conv3d_reparameterization( inputs, filters, @@ -1427,6 +1244,7 @@ def conv3d_reparameterization( bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, reuse=None): + # pylint: disable=g-doc-args """Functional interface for the 3D convolution layer. This layer creates a convolution kernel that is convolved @@ -1448,7 +1266,7 @@ def conv3d_reparameterization( (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: + Args: inputs: Tensor input. filters: Integer, the dimensionality of the output space (i.e. the number of filters in the convolution). @@ -1476,43 +1294,7 @@ def conv3d_reparameterization( all spatial dimensions. Currently, specifying any `dilation_rate` value != 1 is incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - name: A string, the name of the layer. + @{args} reuse: Boolean, whether to reuse the weights of a previous layer by the same name. @@ -1561,6 +1343,7 @@ def conv3d_reparameterization( Diederik P. Kingma, Max Welling. International Conference on Learning Representations, 2014. """ + # pylint: enable=g-doc-args layer = Conv3DReparameterization( filters=filters, kernel_size=kernel_size, @@ -1611,67 +1394,6 @@ class _ConvFlipout(_ConvVariational): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: - rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of n integers, specifying the - length of the convolution window. - strides: An integer or tuple/list of n integers, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, ..., channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, ...)`. - dilation_rate: An integer or tuple/list of n integers, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - Properties: rank: Python integer, dimensionality of convolution. filters: Python integer, dimensionality of the output space. @@ -1694,10 +1416,11 @@ class _ConvFlipout(_ConvVariational): [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb + Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. + International Conference on Learning Representations, 2018. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, rank, @@ -1722,6 +1445,33 @@ class _ConvFlipout(_ConvVariational): seed=None, name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + rank: An integer, the rank of the convolution, e.g. "2" for 2D + convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, ..., + channels)` while `channels_first` corresponds to inputs with shape + `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + @{args} + """ + # pylint: enable=g-doc-args super(_ConvFlipout, self).__init__( rank=rank, filters=filters, @@ -1822,65 +1572,6 @@ class Conv1DFlipout(_ConvFlipout): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of a single integer, specifying the - length of the 1D convolution window. - strides: An integer or tuple/list of a single integer, - specifying the stride length of the convolution. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, length, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, length)`. - dilation_rate: An integer or tuple/list of a single integer, specifying - the dilation rate to use for dilated convolution. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - Properties: filters: Python integer, dimensionality of the output space. kernel_size: Size of the convolution window. @@ -1932,10 +1623,11 @@ class Conv1DFlipout(_ConvFlipout): [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb + Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. + International Conference on Learning Representations, 2018. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, filters, @@ -1959,6 +1651,31 @@ class Conv1DFlipout(_ConvFlipout): seed=None, name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, length, + channels)` while `channels_first` corresponds to inputs with shape + `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + @{args} + """ + # pylint: enable=g-doc-args super(Conv1DFlipout, self).__init__( rank=1, filters=filters, @@ -1982,6 +1699,7 @@ class Conv1DFlipout(_ConvFlipout): name=name, **kwargs) +@docstring_util.expand_docstring(args=doc_args) def conv1d_flipout( inputs, filters, @@ -2005,6 +1723,7 @@ def conv1d_flipout( seed=None, name=None, reuse=None): + # pylint: disable=g-doc-args """Functional interface for 1D convolution layer (e.g. temporal convolution). This layer creates a convolution kernel that is convolved @@ -2029,7 +1748,7 @@ def conv1d_flipout( (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: + Args: inputs: Tensor input. filters: Integer, the dimensionality of the output space (i.e. the number of filters in the convolution). @@ -2049,45 +1768,7 @@ def conv1d_flipout( the dilation rate to use for dilated convolution. Currently, specifying any `dilation_rate` value != 1 is incompatible with specifying any `strides` value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. + @{args} reuse: Boolean, whether to reuse the weights of a previous layer by the same name. @@ -2130,9 +1811,10 @@ def conv1d_flipout( [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb + Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. + International Conference on Learning Representations, 2018. """ + # pylint: enable=g-doc-args layer = Conv1DFlipout( filters=filters, kernel_size=kernel_size, @@ -2184,72 +1866,6 @@ class Conv2DFlipout(_ConvFlipout): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 2 integers, specifying the - height and width of the 2D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 2 integers, - specifying the strides of the convolution along the height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, height, width, channels)` while `channels_first` corresponds to - inputs with shape `(batch, channels, height, width)`. - - dilation_rate: An integer or tuple/list of 2 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - Properties: filters: Python integer, dimensionality of the output space. kernel_size: Size of the convolution window. @@ -2304,10 +1920,11 @@ class Conv2DFlipout(_ConvFlipout): [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb + Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. + International Conference on Learning Representations, 2018. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, filters, @@ -2331,6 +1948,37 @@ class Conv2DFlipout(_ConvFlipout): seed=None, name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, height, + width, channels)` while `channels_first` corresponds to inputs with + shape `(batch, channels, height, width)`. + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + @{args} + """ + # pylint: enable=g-doc-args super(Conv2DFlipout, self).__init__( rank=2, filters=filters, @@ -2354,6 +2002,7 @@ class Conv2DFlipout(_ConvFlipout): name=name, **kwargs) +@docstring_util.expand_docstring(args=doc_args) def conv2d_flipout( inputs, filters, @@ -2377,6 +2026,7 @@ def conv2d_flipout( seed=None, name=None, reuse=None): + # pylint: disable=g-doc-args """Functional interface for the 2D convolution layer. This layer creates a convolution kernel that is convolved @@ -2401,7 +2051,7 @@ def conv2d_flipout( (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: + Args: inputs: Tensor input. filters: Integer, the dimensionality of the output space (i.e. the number of filters in the convolution). @@ -2421,52 +2071,13 @@ def conv2d_flipout( `channels_last` corresponds to inputs with shape `(batch, height, width, channels)` while `channels_first` corresponds to inputs with shape `(batch, channels, height, width)`. - dilation_rate: An integer or tuple/list of 2 integers, specifying the dilation rate to use for dilated convolution. Can be a single integer to specify the same value for all spatial dimensions. Currently, specifying any `dilation_rate` value != 1 is incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. + @{args} reuse: Boolean, whether to reuse the weights of a previous layer by the same name. @@ -2513,9 +2124,10 @@ def conv2d_flipout( [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb + Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. + International Conference on Learning Representations, 2018. """ + # pylint: enable=g-doc-args layer = Conv2DFlipout( filters=filters, kernel_size=kernel_size, @@ -2567,73 +2179,6 @@ class Conv3DFlipout(_ConvFlipout): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: - filters: Integer, the dimensionality of the output space (i.e. the number - of filters in the convolution). - kernel_size: An integer or tuple/list of 3 integers, specifying the - depth, height and width of the 3D convolution window. - Can be a single integer to specify the same value for - all spatial dimensions. - strides: An integer or tuple/list of 3 integers, - specifying the strides of the convolution along the depth, - height and width. - Can be a single integer to specify the same value for - all spatial dimensions. - Specifying any stride value != 1 is incompatible with specifying - any `dilation_rate` value != 1. - padding: One of `"valid"` or `"same"` (case-insensitive). - data_format: A string, one of `channels_last` (default) or `channels_first`. - The ordering of the dimensions in the inputs. - `channels_last` corresponds to inputs with shape - `(batch, depth, height, width, channels)` while `channels_first` - corresponds to inputs with shape - `(batch, channels, depth, height, width)`. - dilation_rate: An integer or tuple/list of 3 integers, specifying - the dilation rate to use for dilated convolution. - Can be a single integer to specify the same value for - all spatial dimensions. - Currently, specifying any `dilation_rate` value != 1 is - incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. - Properties: filters: Python integer, dimensionality of the output space. kernel_size: Size of the convolution window. @@ -2688,10 +2233,11 @@ class Conv3DFlipout(_ConvFlipout): [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb + Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. + International Conference on Learning Representations, 2018. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, filters, @@ -2715,6 +2261,38 @@ class Conv3DFlipout(_ConvFlipout): seed=None, name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or + `channels_first`. The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape `(batch, depth, + height, width, channels)` while `channels_first` corresponds to inputs + with shape `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + @{args} + """ + # pylint: enable=g-doc-args super(Conv3DFlipout, self).__init__( rank=3, filters=filters, @@ -2738,6 +2316,7 @@ class Conv3DFlipout(_ConvFlipout): name=name, **kwargs) +@docstring_util.expand_docstring(args=doc_args) def conv3d_flipout( inputs, filters, @@ -2761,6 +2340,7 @@ def conv3d_flipout( seed=None, name=None, reuse=None): + # pylint: disable=g-doc-args """Functional interface for the 3D convolution layer. This layer creates a convolution kernel that is convolved @@ -2785,7 +2365,7 @@ def conv3d_flipout( (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Arguments: + Args: inputs: Tensor input. filters: Integer, the dimensionality of the output space (i.e. the number of filters in the convolution). @@ -2813,45 +2393,7 @@ def conv3d_flipout( all spatial dimensions. Currently, specifying any `dilation_rate` value != 1 is incompatible with specifying any stride value != 1. - activation: Activation function. Set it to None to maintain a - linear activation. - activity_regularizer: Optional regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: A string, the name of the layer. + @{args} reuse: Boolean, whether to reuse the weights of a previous layer by the same name. @@ -2898,9 +2440,10 @@ def conv3d_flipout( [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on Mini-Batches." - Anonymous. OpenReview, 2017. - https://openreview.net/forum?id=rJnpifWAb + Yeming Wen, Paul Vicol, Jimmy Ba, Dustin Tran, Roger Grosse. + International Conference on Learning Representations, 2018. """ + # pylint: enable=g-doc-args layer = Conv3DFlipout( filters=filters, kernel_size=kernel_size, diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py index 591a8e553de0c194786c7ee8693665f762711b2d..1f1d8fda2a5db4db33a2b6e5d7f027c4b509011a 100644 --- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py +++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.bayesflow.python.ops import docstring_util from tensorflow.contrib.bayesflow.python.ops import layers_util from tensorflow.contrib.distributions.python.ops import independent as independent_lib from tensorflow.python.framework import dtypes @@ -33,6 +34,53 @@ from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.distributions import util as distribution_util +doc_args = """units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name.""" + + class _DenseVariational(layers_lib.Layer): """Abstract densely-connected class (private, used as implementation base). @@ -50,51 +98,6 @@ class _DenseVariational(layers_lib.Layer): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - Properties: units: Python integer, dimensionality of the output space. activation: Activation function (`callable`). @@ -109,6 +112,7 @@ class _DenseVariational(layers_lib.Layer): bias_divergence_fn: `callable` returning divergence. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, units, @@ -126,6 +130,13 @@ class _DenseVariational(layers_lib.Layer): bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + @{args} + """ + # pylint: enable=g-doc-args super(_DenseVariational, self).__init__( trainable=trainable, name=name, @@ -274,51 +285,6 @@ class DenseReparameterization(_DenseVariational): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - Properties: units: Python integer, dimensionality of the output space. activation: Activation function (`callable`). @@ -363,6 +329,7 @@ class DenseReparameterization(_DenseVariational): International Conference on Learning Representations, 2014. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, units, @@ -381,6 +348,13 @@ class DenseReparameterization(_DenseVariational): bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + @{args} + """ + # pylint: enable=g-doc-args super(DenseReparameterization, self).__init__( units=units, activation=activation, @@ -405,6 +379,7 @@ class DenseReparameterization(_DenseVariational): return self._matmul(inputs, self.kernel_posterior_tensor) +@docstring_util.expand_docstring(args=doc_args) def dense_reparameterization( inputs, units, @@ -422,6 +397,7 @@ def dense_reparameterization( bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, reuse=None): + # pylint: disable=g-doc-args """Densely-connected layer with reparameterization estimator. This layer implements the Bayesian variational inference analogue to @@ -444,49 +420,7 @@ def dense_reparameterization( Args: inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. + @{args} Returns: output: `Tensor` representing a the affine transformed input under a random @@ -522,6 +456,7 @@ def dense_reparameterization( Diederik P. Kingma, Max Welling. International Conference on Learning Representations, 2014. """ + # pylint: enable=g-doc-args layer = DenseReparameterization( units, activation=activation, @@ -563,51 +498,6 @@ class DenseLocalReparameterization(_DenseVariational): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - Properties: units: Python integer, dimensionality of the output space. activation: Activation function (`callable`). @@ -652,6 +542,7 @@ class DenseLocalReparameterization(_DenseVariational): Neural Information Processing Systems, 2015. """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, units, @@ -670,6 +561,13 @@ class DenseLocalReparameterization(_DenseVariational): bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + @{args} + """ + # pylint: enable=g-doc-args super(DenseLocalReparameterization, self).__init__( units=units, activation=activation, @@ -705,6 +603,7 @@ class DenseLocalReparameterization(_DenseVariational): return self.kernel_posterior_affine_tensor +@docstring_util.expand_docstring(args=doc_args) def dense_local_reparameterization( inputs, units, @@ -723,6 +622,7 @@ def dense_local_reparameterization( bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), name=None, reuse=None): + # pylint: disable=g-doc-args """Densely-connected layer with local reparameterization estimator. This layer implements the Bayesian variational inference analogue to @@ -745,49 +645,7 @@ def dense_local_reparameterization( Args: inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. + @{args} Returns: output: `Tensor` representing a the affine transformed input under a random @@ -823,6 +681,7 @@ def dense_local_reparameterization( Diederik P. Kingma, Tim Salimans, Max Welling. Neural Information Processing Systems, 2015. """ + # pylint: enable=g-doc-args layer = DenseLocalReparameterization( units, activation=activation, @@ -866,53 +725,6 @@ class DenseFlipout(_DenseVariational): (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` distributions. - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - Properties: units: Python integer, dimensionality of the output space. activation: Activation function (`callable`). @@ -959,6 +771,7 @@ class DenseFlipout(_DenseVariational): https://openreview.net/forum?id=rJnpifWAb """ + @docstring_util.expand_docstring(args=doc_args) def __init__( self, units, @@ -978,6 +791,13 @@ class DenseFlipout(_DenseVariational): seed=None, name=None, **kwargs): + # pylint: disable=g-doc-args + """Construct layer. + + Args: + @{args} + """ + # pylint: enable=g-doc-args super(DenseFlipout, self).__init__( units=units, activation=activation, @@ -1031,6 +851,7 @@ class DenseFlipout(_DenseVariational): return outputs +@docstring_util.expand_docstring(args=doc_args) def dense_flipout( inputs, units, @@ -1050,6 +871,7 @@ def dense_flipout( seed=None, name=None, reuse=None): + # pylint: disable=g-doc-args """Densely-connected layer with Flipout estimator. This layer implements the Bayesian variational inference analogue to @@ -1074,51 +896,7 @@ def dense_flipout( Args: inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - seed: Python scalar `int` which initializes the random number - generator. Default value: `None` (i.e., use global seed). - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. + @{args} Returns: output: `Tensor` representing a the affine transformed input under a random @@ -1155,6 +933,7 @@ def dense_flipout( Anonymous. OpenReview, 2017. https://openreview.net/forum?id=rJnpifWAb """ + # pylint: enable=g-doc-args layer = DenseFlipout( units, activation=activation, diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py index bb8b915a9b540e06899837b2394004119d4ce715..0424b6952bc89ce7fe5b00b0135c9a5fe1faa8cf 100644 --- a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py @@ -36,13 +36,13 @@ __all__ = [ def effective_sample_size(states, - max_lags_threshold=None, - max_lags=None, + filter_threshold=0., + filter_beyond_lag=None, name=None): """Estimate a lower bound on effective sample size for each independent chain. - Roughly speaking, the "effective sample size" (ESS) is the size of an iid - sample with the same variance as `state`. + Roughly speaking, "effective sample size" (ESS) is the size of an iid sample + with the same variance as `state`. More precisely, given a stationary sequence of possibly correlated random variables `X_1, X_2,...,X_N`, each identically distributed ESS is the number @@ -87,21 +87,28 @@ def effective_sample_size(states, This function estimates the above by first estimating the auto-correlation. Since `R_k` must be estimated using only `N - k` samples, it becomes progressively noisier for larger `k`. For this reason, the summation over - `R_k` should be truncated at some number `max_lags < N`. Since many MCMC - methods generate chains where `R_k > 0`, a reasonable critera is to truncate - at the first index where the estimated auto-correlation becomes negative. + `R_k` should be truncated at some number `filter_beyond_lag < N`. Since many + MCMC methods generate chains where `R_k > 0`, a reasonable critera is to + truncate at the first index where the estimated auto-correlation becomes + negative. + + The arguments `filter_beyond_lag`, `filter_threshold` are filters intended to + remove noisy tail terms from `R_k`. They combine in an "OR" manner meaning + terms are removed if they were to be filtered under the `filter_beyond_lag` OR + `filter_threshold` criteria. Args: states: `Tensor` or list of `Tensor` objects. Dimension zero should index identically distributed states. - max_lags_threshold: `Tensor` or list of `Tensor` objects. + filter_threshold: `Tensor` or list of `Tensor` objects. Must broadcast with `state`. The auto-correlation sequence is truncated - after the first appearance of a term less than `max_lags_threshold`. If - both `max_lags` and `max_lags_threshold` are `None`, - `max_lags_threshold` defaults to `0`. - max_lags: `Tensor` or list of `Tensor` objects. Must be `int`-like and - scalar valued. The auto-correlation sequence is truncated to this length. - May be provided only if `max_lags_threshold` is not. + after the first appearance of a term less than `filter_threshold`. + Setting to `None` means we use no threshold filter. Since `|R_k| <= 1`, + setting to any number less than `-1` has the same effect. + filter_beyond_lag: `Tensor` or list of `Tensor` objects. Must be + `int`-like and scalar valued. The auto-correlation sequence is truncated + to this length. Setting to `None` means we do not filter based on number + of lags. name: `String` name to prepend to created ops. Returns: @@ -109,8 +116,8 @@ def effective_sample_size(states, each component of `states`. Shape will be `states.shape[1:]`. Raises: - ValueError: If `states` and `max_lags_threshold` or `states` and `max_lags` - are both lists with different lengths. + ValueError: If `states` and `filter_threshold` or `states` and + `filter_beyond_lag` are both lists with different lengths. """ states_was_list = _is_list_like(states) @@ -118,15 +125,16 @@ def effective_sample_size(states, if not states_was_list: states = [states] - max_lags = _broadcast_maybelist_arg(states, max_lags, "max_lags") - max_lags_threshold = _broadcast_maybelist_arg(states, max_lags_threshold, - "max_lags_threshold") + filter_beyond_lag = _broadcast_maybelist_arg(states, filter_beyond_lag, + "filter_beyond_lag") + filter_threshold = _broadcast_maybelist_arg(states, filter_threshold, + "filter_threshold") # Process items, one at a time. with ops.name_scope(name, "effective_sample_size"): ess_list = [ _effective_sample_size_single_state(s, ml, mlt) - for (s, ml, mlt) in zip(states, max_lags, max_lags_threshold) + for (s, ml, mlt) in zip(states, filter_beyond_lag, filter_threshold) ] if states_was_list: @@ -134,38 +142,31 @@ def effective_sample_size(states, return ess_list[0] -def _effective_sample_size_single_state(states, max_lags, max_lags_threshold): +def _effective_sample_size_single_state(states, filter_beyond_lag, + filter_threshold): """ESS computation for one single Tensor argument.""" - if max_lags is not None and max_lags_threshold is not None: - raise ValueError( - "Expected at most one of max_lags, max_lags_threshold to be provided. " - "Found: {}, {}".format(max_lags, max_lags_threshold)) - - if max_lags_threshold is None: - max_lags_threshold = 0. with ops.name_scope( "effective_sample_size_single_state", - values=[states, max_lags, max_lags_threshold]): + values=[states, filter_beyond_lag, filter_threshold]): states = ops.convert_to_tensor(states, name="states") dt = states.dtype - if max_lags is not None: - auto_corr = sample_stats.auto_correlation( - states, axis=0, max_lags=max_lags) - elif max_lags_threshold is not None: - max_lags_threshold = ops.convert_to_tensor( - max_lags_threshold, dtype=dt, name="max_lags_threshold") - auto_corr = sample_stats.auto_correlation(states, axis=0) + # filter_beyond_lag == None ==> auto_corr is the full sequence. + auto_corr = sample_stats.auto_correlation( + states, axis=0, max_lags=filter_beyond_lag) + if filter_threshold is not None: + filter_threshold = ops.convert_to_tensor( + filter_threshold, dtype=dt, name="filter_threshold") # Get a binary mask to zero out values of auto_corr below the threshold. # mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i, # mask[i, ...] = 0, otherwise. # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...] # Building step by step, - # Assume auto_corr = [1, 0.5, 0.0, 0.3], and max_lags_threshold = 0.2. + # Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2. # Step 1: mask = [False, False, True, False] - mask = auto_corr < max_lags_threshold + mask = auto_corr < filter_threshold # Step 2: mask = [0, 0, 1, 1] mask = math_ops.cast(mask, dtype=dt) # Step 3: mask = [0, 0, 1, 2] @@ -173,14 +174,12 @@ def _effective_sample_size_single_state(states, max_lags, max_lags_threshold): # Step 4: mask = [1, 1, 0, 0] mask = math_ops.maximum(1. - mask, 0.) auto_corr *= mask - else: - auto_corr = sample_stats.auto_correlation(states, axis=0) # With R[k] := auto_corr[k, ...], # ESS = N / {1 + 2 * Sum_{k=1}^N (N - k) / N * R[k]} # = N / {-1 + 2 * Sum_{k=0}^N (N - k) / N * R[k]} (since R[0] = 1) # approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]} - #, where M is the max_lags truncation point chosen above. + # where M is the filter_beyond_lag truncation point chosen above. # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total # ndims the same as auto_corr diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py index 7bdeaa862d5bb64fa8940df453c7aa2b66023eda..e7fcbc65ef379e84a140a06e020549f74f905a99 100644 --- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py +++ b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py @@ -25,9 +25,10 @@ from tensorflow.contrib.bayesflow.python.ops.metropolis_hastings_impl import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ + 'kernel', 'evolve', - 'uniform_random_proposal', - 'normal_random_proposal', + 'proposal_uniform', + 'proposal_normal', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py index dc1ac68ce009fa46d6c05a3200a29d9fdf245707..05aa134ed5c11092316af5f3e45ba07fdb491e90 100644 --- a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions to create a Markov Chain Monte Carlo Metropolis step. +"""Metropolis-Hastings and proposal distributions. +@@kernel @@evolve -@@uniform_random_proposal -@@normal_random_proposal +@@proposal_uniform +@@proposal_normal """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -31,123 +34,198 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops __all__ = [ - 'evolve', - 'uniform_random_proposal', - 'normal_random_proposal', + "kernel", + "evolve", + "proposal_uniform", + "proposal_normal", ] -def _single_iteration(current_state, current_log_density, - log_unnormalized_prob_fn, proposal_fn, seed=None, - name='None'): - """Performs a single Metropolis-Hastings step. +KernelResults = collections.namedtuple( + "KernelResults", + [ + "log_accept_ratio", + "current_target_log_prob", # "Current result" means "accepted". + "is_accepted", + "proposed_state", + ]) + + +def kernel(target_log_prob_fn, + proposal_fn, + current_state, + seed=None, + current_target_log_prob=None, + name=None): + """Runs the Metropolis-Hastings transition kernel. + + This function can update multiple chains in parallel. It assumes that all + leftmost dimensions of `current_state` index independent chain states (and are + therefore updated independently). The output of `target_log_prob_fn()` should + sum log-probabilities across all event dimensions. Slices along the rightmost + dimensions may have different target distributions; for example, + `current_state[0, :]` could have a different target distribution from + `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of + independent chains is `tf.size(target_log_prob_fn(*current_state))`.) Args: - current_state: Float-like `Tensor` (i.e., `dtype` is either - `tf.float16`, `tf.float32` or `tf.float64`) of any shape that can - be consumed by the `log_unnormalized_prob_fn` and `proposal_fn` - callables. - current_log_density: Float-like `Tensor` with `dtype` and shape equivalent - to `log_unnormalized_prob_fn(current_state)`, i.e., matching the result of - `log_unnormalized_prob_fn` invoked at `current_state`. - log_unnormalized_prob_fn: A Python callable evaluated at - `current_state` and returning a float-like `Tensor` of log target-density - up to a normalizing constant. In other words, - `log_unnormalized_prob_fn(x) = log(g(x))`, where - `target_density = g(x)/Z` for some constant `A`. The shape of the input - tensor is the same as the shape of the `current_state`. The shape of the - output tensor is either - (a). Same as the input shape if the density being sampled is one - dimensional, or - (b). If the density is defined for `events` of shape - `event_shape = [E1, E2, ... Ee]`, then the input tensor should be of - shape `batch_shape + event_shape`, where `batch_shape = [B1, ..., Bb]` - and the result must be of shape [B1, ..., Bb]. For example, if the - distribution that is being sampled is a 10 dimensional normal, - then the input tensor may be of shape [100, 10] or [30, 20, 10]. The - last dimension will then be 'consumed' by `log_unnormalized_prob_fn` - and it should return tensors of shape [100] and [30, 20] respectively. - proposal_fn: A callable accepting a real valued `Tensor` of current sample - points and returning a tuple of two `Tensors`. The first element of the - pair is a `Tensor` containing the proposal state and should have - the same shape as the input `Tensor`. The second element of the pair gives - the log of the ratio of the probability of transitioning from the - proposal points to the input points and the probability of transitioning - from the input points to the proposal points. If the proposal is - symmetric (e.g., random walk, where the proposal is either - normal or uniform centered at `current_state`), i.e., - Probability(Proposal -> Current) = Probability(Current -> Proposal) - the second value should be set to `None` instead of explicitly supplying a - tensor of zeros. In addition to being convenient, this also leads to a - more efficient graph. - seed: `int` or None. The random seed for this `Op`. If `None`, no seed is - applied. - name: Python `str` name prefix for ops managed by this function. + target_log_prob_fn: Python callable which takes an argument like + `current_state` (or `*current_state` if it's a list) and returns its + (possibly unnormalized) log-density under the target distribution. + proposal_fn: Python callable which takes an argument like `current_state` + (or `*current_state` if it's a list) and returns a tuple of proposed + states of same shape as `state`, and a log ratio `Tensor` of same shape + as `current_target_log_prob`. The log ratio is the log-probability of + `state` given proposed states minus the log-probability of proposed + states given `state`. If the proposal is symmetric, set the second value + to `None`: this enables more efficient computation than explicitly + supplying a tensor of zeros. + current_state: `Tensor` or Python `list` of `Tensor`s representing the + current state(s) of the Markov chain(s). The first `r` dimensions index + independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. + seed: Python integer to seed the random number generator. + current_target_log_prob: (Optional) `Tensor` representing the value of + `target_log_prob_fn` at the `current_state`. The only reason to + specify this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). + name: A name of the operation (optional). Returns: - next_state: `Tensor` with `dtype` and shape matching `current_state`. - Created by propagating the chain by one step, starting from + next_state: Tensor or Python list of `Tensor`s representing the state(s) + of the Markov chain(s) at each result step. Has same shape as `current_state`. - next_log_density: `Tensor` with `dtype` and shape matching - `current_log_density`, which is equal to the value of the unnormalized - `log_unnormalized_prob_fn` computed at `next_state`. - log_accept_ratio: `Tensor` with `dtype` and shape matching - `current_log_density`. Stands for the log of Metropolis-Hastings - acceptance ratio used in generating the `next_state`. - """ + kernel_results: `collections.namedtuple` of internal calculations used to + advance the chain. - with ops.name_scope(name, 'single_iteration', [current_state]): - # The proposed state and the log of the corresponding Hastings ratio. - proposal_state, log_transit_ratio = proposal_fn(current_state) - - # If the log ratio is None, assume that the transitions are symmetric, - # i.e., Prob(Current -> Proposed) = Prob(Proposed -> Current). - if log_transit_ratio is None: - log_transit_ratio = 0. - - # Log-density of the proposal state. - proposal_log_density = log_unnormalized_prob_fn(proposal_state) - - # Ops to compute the log of the acceptance ratio. Recall that the - # acceptance ratio is: [Prob(Proposed) / Prob(Current)] * - # [Prob(Proposed -> Current) / Prob(Current -> Proposed)]. The log of the - # second term is the log_transit_ratio. - with ops.name_scope('accept_reject'): - # The log of the acceptance ratio. - log_accept_ratio = (proposal_log_density - current_log_density - + log_transit_ratio) - - # A proposal is accepted or rejected depending on the acceptance ratio. - # If the acceptance ratio is greater than 1 then it is always accepted. - # If the acceptance ratio is less than 1 then the proposal is accepted - # with probability = acceptance ratio. As we are working in log space to - # prevent over/underflows, this logic is expressed in log terms below. - # If a proposal is accepted we place a True in the acceptance state - # tensor and if it is to be rejected we place a False. - # The log_draws below have to be compared to the log_accept_ratio so we - # make sure that they have the same data type. - log_draws = math_ops.log(random_ops.random_uniform( - array_ops.shape(current_log_density), seed=seed, - dtype=log_accept_ratio.dtype)) - is_proposal_accepted = log_draws < log_accept_ratio - - # The acceptance state decides which elements of the current state are to - # be replaced with the corresponding elements in the proposal state. - with ops.name_scope(name, 'metropolis_single_step', - [current_state, current_log_density]): - next_log_density = array_ops.where(is_proposal_accepted, - proposal_log_density, - current_log_density) - next_state = array_ops.where(is_proposal_accepted, proposal_state, - current_state) - - return next_state, next_log_density, log_accept_ratio + #### Examples + + We illustrate Metropolis-Hastings on a Normal likelihood with + unknown mean. + + ```python + tfd = tf.contrib.distributions + tfp = tf.contrib.bayesflow + + loc = tf.get_variable("loc", initializer=1.) + x = tf.constant([0.0] * 50) + + def make_target_log_prob_fn(x): + def target_log_prob_fn(loc): + prior = tfd.Normal(loc=0., scale=1.) + likelihood = tfd.Independent( + tfd.Normal(loc=loc, scale=0.1), + reinterpreted_batch_ndims=1) + return prior.log_prob(loc) + likelihood.log_prob(x) + return target_log_prob_fn + + next_state, kernel_results = tfp.metropolis_hastings.kernel( + target_log_prob_fn=make_target_log_prob_fn(x), + proposal_fn=tfp.metropolis_hastings.proposal_normal(), + current_state=loc) + loc_update = loc.assign(next_state) + ``` + + We illustrate Metropolis-Hastings on a Normal likelihood with + unknown mean and variance. We apply 4 chains. + + ```python + tfd = tf.contrib.distributions + tfp = tf.contrib.bayesflow + + num_chains = 4 + loc = tf.get_variable("loc", shape=[num_chains], + initializer=tf.random_normal_initializer()) + scale = tf.get_variable("scale", shape=[num_chains], + initializer=tf.ones_initializer()) + x = tf.constant([0.0] * 50) + + def make_target_log_prob_fn(x): + data = tf.reshape(x, shape=[-1, 1]) + def target_log_prob_fn(loc, scale): + prior_loc = tfd.Normal(loc=0., scale=1.) + prior_scale = tfd.InverseGamma(concentration=1., rate=1.) + likelihood = tfd.Independent( + tfd.Normal(loc=loc, scale=scale), + reinterpreted_batch_ndims=1) + return (prior_loc.log_prob(loc) + + prior_scale.log_prob(scale) + + likelihood.log_prob(data)) + return target_log_prob_fn + + def proposal_fn(loc, scale): + loc_proposal = tfp.metropolis_hastings.proposal_normal() + scale_proposal = tfp.metropolis_hastings.proposal_uniform(minval=-1.) + proposed_loc, _ = loc_proposal(loc) + proposed_scale, _ = scale_proposal(scale) + proposed_scale = tf.maximum(proposed_scale, 0.01) + return [proposed_loc, proposed_scale], None + + next_state, kernel_results = tfp.metropolis_hastings.kernel( + target_log_prob_fn=make_target_log_prob_fn(x), + proposal_fn=proposal_fn, + current_state=[loc, scale]) + train_op = tf.group(loc.assign(next_state[0]), + scale.assign(next_state[1])) + ``` + + """ + with ops.name_scope( + name, "metropolis_hastings_kernel", + [current_state, seed, current_target_log_prob]): + with ops.name_scope("initialize"): + maybe_expand = lambda x: list(x) if _is_list_like(x) else [x] + current_state_parts = maybe_expand(current_state) + if current_target_log_prob is None: + current_target_log_prob = target_log_prob_fn(*current_state_parts) + + proposed_state, log_transit_ratio = proposal_fn(*current_state_parts) + proposed_state_parts = maybe_expand(proposed_state) + + proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts) + + with ops.name_scope( + "accept_reject", + [current_state_parts, proposed_state_parts, + current_target_log_prob, proposed_target_log_prob]): + log_accept_ratio = proposed_target_log_prob - current_target_log_prob + if log_transit_ratio is not None: + # If the log_transit_ratio is None, then assume the proposal is + # symmetric, i.e., + # log p(old | new) - log p(new | old) = 0. + log_accept_ratio += log_transit_ratio + + # u < exp(log_accept_ratio), where u~Uniform[0,1) + # ==> log(u) < log_accept_ratio + random_value = random_ops.random_uniform( + array_ops.shape(log_accept_ratio), + dtype=log_accept_ratio.dtype, + seed=seed) + random_negative = math_ops.log(random_value) + is_accepted = random_negative < log_accept_ratio + next_state_parts = [array_ops.where(is_accepted, + proposed_state_part, + current_state_part) + for proposed_state_part, current_state_part in + zip(proposed_state_parts, current_state_parts)] + accepted_log_prob = array_ops.where(is_accepted, + proposed_target_log_prob, + current_target_log_prob) + maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] + return [ + maybe_flatten(next_state_parts), + KernelResults( + log_accept_ratio=log_accept_ratio, + current_target_log_prob=accepted_log_prob, + is_accepted=is_accepted, + proposed_state=maybe_flatten(proposed_state_parts), + ), + ] def evolve(initial_sample, initial_log_density, initial_log_accept_ratio, - log_unnormalized_prob_fn, + target_log_prob_fn, proposal_fn, n_steps=1, seed=None, @@ -162,9 +240,11 @@ def evolve(initial_sample, The probability distribution may have an unknown normalization constan. We parameterize the probability density as follows: - ``` - f(x) = exp(L(x) + constant) - ``` + + ```none + f(x) = exp(L(x) + constant) + ``` + Here `L(x)` is any continuous function with an (possibly unknown but finite) upper bound, i.e. there exists a number beta such that `L(x)< beta < infinity` for all x. The constant is the normalization needed @@ -188,72 +268,77 @@ def evolve(initial_sample, The following example, demonstrates the use to generate a 1000 uniform random walk Metropolis samplers run in parallel for the normal target distribution. + ```python - n = 3 # dimension of the problem - - # Generate 1000 initial values randomly. Each of these would be an - # independent starting point for a Markov chain. - state = tf.get_variable( - 'state',initializer=tf.random_normal([1000, n], mean=3.0, - dtype=tf.float64, seed=42)) - - # Computes the log(p(x)) for the unit normal density and ignores the - # normalization constant. - def log_density(x): - return - tf.reduce_sum(x * x, reduction_indices=-1) / 2.0 - - # Initial log-density value - state_log_density = tf.get_variable( - 'state_log_density', initializer=log_density(state.initialized_value())) - - # A variable to store the log_acceptance_ratio: - log_acceptance_ratio = tf.get_variable( - 'log_acceptance_ratio', initializer=tf.zeros([1000], dtype=tf.float64)) - - # Generates random proposals by moving each coordinate uniformly and - # independently in a box of size 2 centered around the current value. - # Returns the new point and also the log of the Hastings ratio (the - # ratio of the probability of going from the proposal to origin and the - # probability of the reverse transition). When this ratio is 1, the value - # may be omitted and replaced by None. - def random_proposal(x): - return (x + tf.random_uniform(tf.shape(x), minval=-1, maxval=1, - dtype=x.dtype, seed=12)), None - - # Create the op to propagate the chain for 100 steps. - stepper = mh.evolve( - state, state_log_density, log_acceptance_ratio, - log_density, random_proposal, n_steps=100, seed=123) - init = tf.initialize_all_variables() - with tf.Session() as sess: - sess.run(init) - # Run the chains for a total of 1000 steps and print out the mean across - # the chains every 100 iterations. - for n_iter in range(10): - # Executing the stepper advances the chain to the next state. - sess.run(stepper) - # Print out the current value of the mean(sample) for every dimension. - print(np.mean(sess.run(state), 0)) - # Estimated covariance matrix - samples = sess.run(state) - print('') - print(np.cov(samples, rowvar=False)) + n = 3 # dimension of the problem + + # Generate 1000 initial values randomly. Each of these would be an + # independent starting point for a Markov chain. + state = tf.get_variable( + "state", + initializer=tf.random_normal([1000, n], + mean=3.0, + dtype=tf.float64, + seed=42)) + + # Computes the log(p(x)) for the unit normal density and ignores the + # normalization constant. + def log_density(x): + return -tf.reduce_sum(x * x, reduction_indices=-1) / 2.0 + + # Initial log-density value + state_log_density = tf.get_variable( + "state_log_density", + initializer=log_density(state.initialized_value())) + + # A variable to store the log_acceptance_ratio: + log_acceptance_ratio = tf.get_variable( + "log_acceptance_ratio", + initializer=tf.zeros([1000], dtype=tf.float64)) + + # Generates random proposals by moving each coordinate uniformly and + # independently in a box of size 2 centered around the current value. + # Returns the new point and also the log of the Hastings ratio (the + # ratio of the probability of going from the proposal to origin and the + # probability of the reverse transition). When this ratio is 1, the value + # may be omitted and replaced by None. + def random_proposal(x): + return (x + tf.random_uniform(tf.shape(x), minval=-1, maxval=1, + dtype=x.dtype, seed=12)), None + + # Create the op to propagate the chain for 100 steps. + stepper = mh.evolve( + state, state_log_density, log_acceptance_ratio, + log_density, random_proposal, n_steps=100, seed=123) + init = tf.initialize_all_variables() + with tf.Session() as sess: + sess.run(init) + # Run the chains for a total of 1000 steps and print out the mean across + # the chains every 100 iterations. + for n_iter in range(10): + # Executing the stepper advances the chain to the next state. + sess.run(stepper) + # Print out the current value of the mean(sample) for every dimension. + print(np.mean(sess.run(state), 0)) + # Estimated covariance matrix + samples = sess.run(state) + print(np.cov(samples, rowvar=False)) ``` Args: initial_sample: A float-like `tf.Variable` of any shape that can - be consumed by the `log_unnormalized_prob_fn` and `proposal_fn` + be consumed by the `target_log_prob_fn` and `proposal_fn` callables. initial_log_density: Float-like `tf.Variable` with `dtype` and shape - equivalent to `log_unnormalized_prob_fn(initial_sample)`, i.e., matching - the result of `log_unnormalized_prob_fn` invoked at `current_state`. + equivalent to `target_log_prob_fn(initial_sample)`, i.e., matching + the result of `target_log_prob_fn` invoked at `current_state`. initial_log_accept_ratio: A `tf.Variable` with `dtype` and shape matching `initial_log_density`. Stands for the log of Metropolis-Hastings acceptance ratio after propagating the chain for `n_steps`. - log_unnormalized_prob_fn: A Python callable evaluated at + target_log_prob_fn: A Python callable evaluated at `current_state` and returning a float-like `Tensor` of log target-density up to a normalizing constant. In other words, - `log_unnormalized_prob_fn(x) = log(g(x))`, where + `target_log_prob_fn(x) = log(g(x))`, where `target_density = g(x)/Z` for some constant `A`. The shape of the input tensor is the same as the shape of the `current_state`. The shape of the output tensor is either @@ -265,7 +350,7 @@ def evolve(initial_sample, and the result must be of shape [B1, ..., Bb]. For example, if the distribution that is being sampled is a 10 dimensional normal, then the input tensor may be of shape [100, 10] or [30, 20, 10]. The - last dimension will then be 'consumed' by `log_unnormalized_prob_fn` + last dimension will then be 'consumed' by `target_log_prob_fn` and it should return tensors of shape [100] and [30, 20] respectively. proposal_fn: A callable accepting a real valued `Tensor` of current sample points and returning a tuple of two `Tensors`. The first element of the @@ -289,42 +374,48 @@ def evolve(initial_sample, forward_step: an `Op` to step the Markov chain forward for `n_steps`. """ - with ops.name_scope(name, 'metropolis_hastings', [initial_sample]): + with ops.name_scope(name, "metropolis_hastings", [initial_sample]): current_state = initial_sample - current_log_density = initial_log_density + current_target_log_prob = initial_log_density log_accept_ratio = initial_log_accept_ratio - # Stop condition for the while_loop - def stop_condition(i, _): - return i < n_steps - - def step(i, loop_vars): - """Wrap `_single_iteration` for `while_loop`.""" - state = loop_vars[0] - state_log_density = loop_vars[1] - return i + 1, list(_single_iteration(state, state_log_density, - log_unnormalized_prob_fn, - proposal_fn, seed=seed)) - - loop_vars = [current_state, current_log_density, log_accept_ratio] - # Build an `Op` to evolve the Markov chain for `n_steps` - (_, [end_state, end_log_density, end_log_acceptance]) = ( + def step(i, current_state, current_target_log_prob, log_accept_ratio): + """Wrap single Markov chain iteration in `while_loop`.""" + next_state, kernel_results = kernel( + target_log_prob_fn=target_log_prob_fn, + proposal_fn=proposal_fn, + current_state=current_state, + current_target_log_prob=current_target_log_prob, + seed=seed) + accepted_log_prob = kernel_results.current_target_log_prob + log_accept_ratio = kernel_results.log_accept_ratio + return i + 1, next_state, accepted_log_prob, log_accept_ratio + + (_, accepted_state, accepted_target_log_prob, accepted_log_accept_ratio) = ( control_flow_ops.while_loop( - stop_condition, step, - (0, loop_vars), - parallel_iterations=1, swap_memory=1)) + cond=lambda i, *ignored_args: i < n_steps, + body=step, + loop_vars=[ + 0, # i + current_state, + current_target_log_prob, + log_accept_ratio, + ], + parallel_iterations=1 if seed is not None else 10, + # TODO(b/73775595): Confirm optimal setting of swap_memory. + swap_memory=1)) forward_step = control_flow_ops.group( - state_ops.assign(current_log_density, end_log_density), - state_ops.assign(current_state, end_state), - state_ops.assign(log_accept_ratio, end_log_acceptance)) + state_ops.assign(current_target_log_prob, accepted_target_log_prob), + state_ops.assign(current_state, accepted_state), + state_ops.assign(log_accept_ratio, accepted_log_accept_ratio)) return forward_step -def uniform_random_proposal(step_size=1., - seed=None, - name=None): +def proposal_uniform(step_size=1., + seed=None, + name=None): """Returns a callable that adds a random uniform tensor to the input. This function returns a callable that accepts one `Tensor` argument of any @@ -346,11 +437,13 @@ def uniform_random_proposal(step_size=1., Returns: proposal_fn: A callable accepting one float-like `Tensor` and returning a - 2-tuple. The first value in the tuple is a `Tensor` of the same shape and - dtype as the input argument and the second element of the tuple is None. + 2-tuple. The first value in the tuple is a `Tensor` of the same shape and + dtype as the input argument and the second element of the tuple is None. """ - with ops.name_scope(name, 'uniform_random_proposal', [step_size]): + with ops.name_scope(name, "proposal_uniform", [step_size]): + step_size = ops.convert_to_tensor(step_size, name="step_size") + def proposal_fn(input_state, name=None): """Adds a uniform perturbation to the input state. @@ -359,12 +452,12 @@ def uniform_random_proposal(step_size=1., name: A string that sets the name for this `Op`. Returns: - proposal_state: A float-like `Tensot` with `dtype` and shape matching + proposal_state: A float-like `Tensor` with `dtype` and shape matching `input_state`. log_transit_ratio: `None`. Proposal is symmetric. """ - with ops.name_scope(name, 'proposer', [input_state]): - input_state = ops.convert_to_tensor(input_state, name='input_state') + with ops.name_scope(name, "proposer", [input_state]): + input_state = ops.convert_to_tensor(input_state, name="input_state") return input_state + random_ops.random_uniform( array_ops.shape(input_state), minval=-step_size, @@ -373,9 +466,9 @@ def uniform_random_proposal(step_size=1., return proposal_fn -def normal_random_proposal(scale=1., - seed=None, - name=None): +def proposal_normal(scale=1., + seed=None, + name=None): """Returns a callable that adds a random normal tensor to the input. This function returns a callable that accepts one `Tensor` argument of any @@ -398,11 +491,13 @@ def normal_random_proposal(scale=1., Returns: proposal_fn: A callable accepting one float-like `Tensor` and returning a - 2-tuple. The first value in the tuple is a `Tensor` of the same shape and - dtype as the input argument and the second element of the tuple is None. + 2-tuple. The first value in the tuple is a `Tensor` of the same shape and + dtype as the input argument and the second element of the tuple is None. """ - with ops.name_scope(name, 'normal_random_proposal', [scale]): + with ops.name_scope(name, "proposal_normal", [scale]): + scale = ops.convert_to_tensor(scale, name="scale") + def proposal_fn(input_state, name=None): """Adds a normal perturbation to the input state. @@ -411,16 +506,22 @@ def normal_random_proposal(scale=1., name: A string that sets the name for this `Op`. Returns: - proposal_state: A float-like `Tensot` with `dtype` and shape matching + proposal_state: A float-like `Tensor` with `dtype` and shape matching `input_state`. log_transit_ratio: `None`. Proposal is symmetric. """ - with ops.name_scope(name, 'proposer', [input_state]): - input_state = ops.convert_to_tensor(input_state, name='input_state') + with ops.name_scope(name, "proposer", [input_state]): + input_state = ops.convert_to_tensor(input_state, name="input_state") return input_state + random_ops.random_normal( array_ops.shape(input_state), mean=0., stddev=scale, + dtype=scale.dtype, seed=seed), None return proposal_fn + + +def _is_list_like(x): + """Helper which returns `True` if input is `list`-like.""" + return isinstance(x, (tuple, list)) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 31f5c444817b9b82723c86bea3504d4934e57eb8..23ba76210b3b68d0d0b2eef9d4040882654bdad9 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -93,7 +93,9 @@ def make_custom_export_strategy(name, "w") as f: f.write("\n".join("%s, %f" % (k, v) for k, v in sorted_by_importance)) return result_dir - return export_strategy.ExportStrategy(name, export_fn) + + return export_strategy.ExportStrategy( + name, export_fn, strip_default_attrs=True) def convert_to_universal_format(dtec, sorted_feature_names, diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 3ebf28ea442edf87815c39971ae9e01a2a8aae9a..94aeb2c7bb48c6eddb6c7894f8bf6f1567470113 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -126,7 +126,8 @@ class DecisionTreeEnsembleResource : public StampedResource { return; } used_ids->Add(handler_id); - std::rotate(first, used_ids->end() - 1, used_ids->end()); + // Keep the list of used handlers sorted. + std::sort(used_ids->begin(), used_ids->end()); } std::vector GetUsedHandlers() const { diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 80e18a43a71cc9d6c9e2ccf5836e50c6427a30f6..1a124eca364424b651de86bfaac6f33ad131804b 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -30,6 +30,7 @@ py_library( "python/training/__init__.py", ], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":cluster_resolver_py", ":gce_cluster_resolver_py", @@ -109,5 +110,6 @@ tf_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:training", ], + grpc_enabled = True, main = "python/training/tpu_cluster_resolver_test.py", ) diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py index b04822fa9d66465e34a545d3b00c399bbb196514..1c480b25134b1e54200e0ddb780bd7bb0f122341 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver.py @@ -53,11 +53,16 @@ class ClusterResolver(object): raise NotImplementedError( 'cluster_spec is not implemented for {}.'.format(self)) + @abc.abstractmethod + def master(self): + """...""" + raise NotImplementedError('master is not implemented for {}.'.format(self)) + class SimpleClusterResolver(ClusterResolver): """Simple implementation of ClusterResolver that accepts a ClusterSpec.""" - def __init__(self, cluster_spec): + def __init__(self, cluster_spec, master=''): """Creates a SimpleClusterResolver from a ClusterSpec.""" super(SimpleClusterResolver, self).__init__() @@ -65,10 +70,18 @@ class SimpleClusterResolver(ClusterResolver): raise TypeError('cluster_spec must be a ClusterSpec.') self._cluster_spec = cluster_spec + if not isinstance(master, str): + raise TypeError('master must be a string.') + self._master = master + def cluster_spec(self): """Returns the ClusterSpec passed into the constructor.""" return self._cluster_spec + def master(self): + """Returns the master address to use when creating a session.""" + return self._master + class UnionClusterResolver(ClusterResolver): """Performs a union on underlying ClusterResolvers. @@ -87,9 +100,13 @@ class UnionClusterResolver(ClusterResolver): Raises: TypeError: If any argument is not a subclass of `ClusterResolvers`. + ValueError: If there are no arguments passed. """ super(UnionClusterResolver, self).__init__() + if not args: + raise ValueError('At least one ClusterResolver is required.') + for cluster_resolver in args: if not isinstance(cluster_resolver, ClusterResolver): raise TypeError('All arguments must be a sub-class of ' @@ -169,3 +186,7 @@ class UnionClusterResolver(ClusterResolver): merged_cluster[job_name].update(task_dict) return ClusterSpec(merged_cluster) + + def master(self): + """master returns the master address from the first cluster resolver.""" + return self._cluster_resolvers[0].master() diff --git a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py index dbfb77723cdaab66e29bb41b764593bb5fd61b35..d9c97d53eb3663f6ab2f7b40395592dc7638b896 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/cluster_resolver_test.py @@ -234,5 +234,7 @@ class UnionClusterResolverTest(test.TestCase): self._verifyClusterSpecEquality(cluster_spec, expected_proto) +# TODO(saeta): Include tests for master resolution + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py index d6f2eced93ba4fda5ac27f9412b6f729981f4f40..3f5824128948453634bc5e5a7d6fdeedae60f5bd 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/gce_cluster_resolver.py @@ -134,3 +134,6 @@ class GceClusterResolver(ClusterResolver): worker_list.sort() return ClusterSpec({self._job_name: worker_list}) + + def master(self): + return '' diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index a6a6e642e4e4c721b94821a70d55d6fe931347d6..aeccf4c06bb57a03ac79e20a5e001935d847b2a7 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -23,7 +23,8 @@ from six.moves.urllib.request import Request from six.moves.urllib.request import urlopen from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver -from tensorflow.python.training.server_lib import ClusterSpec +from tensorflow.python.training import server_lib +from tensorflow.python.util import compat _GOOGLE_API_CLIENT_INSTALLED = True try: @@ -46,13 +47,23 @@ class TPUClusterResolver(ClusterResolver): req = Request('http://metadata/computeMetadata/v1/%s' % path, headers={'Metadata-Flavor': 'Google'}) resp = urlopen(req) - return resp.read() + return compat.as_bytes(resp.read()) + + def _shouldResolve(self): + if (self._tpu == compat.as_bytes('') or + self._tpu == compat.as_bytes('local') or + self._tpu.startswith(compat.as_bytes('/bns')) or + self._tpu.startswith(compat.as_bytes('grpc://'))): + return False + return True def __init__(self, - tpu_names, + tpu, zone=None, project=None, - job_name='tpu_worker', + job_name='worker', + coordinator_name='coordinator', + coordinator_address=None, credentials='default', service=None): """Creates a new TPUClusterResolver object. @@ -61,7 +72,11 @@ class TPUClusterResolver(ClusterResolver): for the IP addresses and ports of each Cloud TPU listed. Args: - tpu_names: A list of names of the target Cloud TPUs. + tpu: Either a string, or a list of strings corresponding to the TPUs to + use. If the single string is the empty string, the string 'local', or a + string that begins with 'grpc://' or '/bns', then it is assumed to not + correspond with a Cloud TPU and will instead be passed as the session + master and no ClusterSpec propagation will be done. zone: Zone where the TPUs are located. If omitted or empty, we will assume that the zone of the TPU is the same as the zone of the GCE VM, which we will try to discover from the GCE metadata service. @@ -69,6 +84,12 @@ class TPUClusterResolver(ClusterResolver): empty, we will try to discover the project name of the GCE VM from the GCE metadata service. job_name: Name of the TensorFlow job the TPUs belong to. + coordinator_name: The name to use for the coordinator. Set to None if the + coordinator should not be included in the computed ClusterSpec. + coordinator_address: The address of the coordinator (typically an ip:port + pair). If set to None, a TF server will be started. If coordinator_name + is None, a TF server will not be started even if coordinator_address is + None. credentials: GCE Credentials. If None, then we use default credentials from the oauth2client service: The GCE API object returned by the googleapiclient.discovery @@ -77,26 +98,36 @@ class TPUClusterResolver(ClusterResolver): Raises: ImportError: If the googleapiclient is not installed. + ValueError: If no TPUs are specified. """ + if isinstance(tpu, list): + if not tpu: + raise ValueError('At least one TPU must be specified.') + if len(tpu) != 1: + raise NotImplementedError( + 'Using multiple TPUs in a single session is not yet implemented') + tpu = tpu[0] + self._tpu = compat.as_bytes(tpu) # self._tpu is always bytes + self._job_name = job_name + self._credentials = credentials - if not project: - project = self._requestComputeMetadata('/project/project-id') + should_resolve = self._shouldResolve() - if not zone: - zone_path = self._requestComputeMetadata('/instance/zone') + if not project and should_resolve: + project = self._requestComputeMetadata('project/project-id') + + if not zone and should_resolve: + zone_path = self._requestComputeMetadata('instance/zone') zone = zone_path.split('/')[-1] self._project = project self._zone = zone - self._tpu_names = tpu_names - self._job_name = job_name - self._credentials = credentials - if credentials == 'default': + if credentials == 'default' and should_resolve: if _GOOGLE_API_CLIENT_INSTALLED: self._credentials = GoogleCredentials.get_application_default() - if service is None: + if service is None and should_resolve: if not _GOOGLE_API_CLIENT_INSTALLED: raise ImportError('googleapiclient must be installed before using the ' 'TPU cluster resolver') @@ -107,25 +138,41 @@ class TPUClusterResolver(ClusterResolver): else: self._service = service - def get_master(self): - """Get the ClusterSpec grpc master path. + self._coordinator_name = coordinator_name + if coordinator_name and not coordinator_address and should_resolve: + self._start_local_server() + else: + self._coordinator_address = coordinator_address + + def master(self): + """Get the Master string to be used for the session. + + In the normal case, this returns the grpc path (grpc://1.2.3.4:8470) of + first instance in the ClusterSpec returned by the cluster_spec function. - This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the - ClusterSpec returned by the cluster_spec function. This is suitable for use - for the `master` argument in tf.Session() when you are using one TPU. + If a non-TPU name is used when constructing a TPUClusterResolver, that will + be returned instead (e.g. If the tpus argument's value when constructing + this TPUClusterResolver was 'grpc://10.240.1.2:8470', + 'grpc://10.240.1.2:8470' will be returned). Returns: - string, the grpc path of the first instance in the ClusterSpec. + string, the connection string to use when creating a session. Raises: ValueError: If none of the TPUs specified exists. """ + if not self._shouldResolve(): + return self._tpu + job_tasks = self.cluster_spec().job_tasks(self._job_name) if not job_tasks: raise ValueError('No TPUs exists with the specified names exist.') return 'grpc://' + job_tasks[0] + def get_master(self): + return self.master() + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. @@ -134,17 +181,54 @@ class TPUClusterResolver(ClusterResolver): Returns: A ClusterSpec containing host information returned from Cloud TPUs. - """ - worker_list = [] - - for tpu_name in self._tpu_names: - full_name = 'projects/%s/locations/%s/nodes/%s' % ( - self._project, self._zone, tpu_name) - request = self._service.projects().locations().nodes().get(name=full_name) - response = request.execute() - if 'health' in response and response['health'] == 'HEALTHY': - instance_url = '%s:%s' % (response['ipAddress'], response['port']) - worker_list.append(instance_url) - - return ClusterSpec({self._job_name: worker_list}) + Raises: + RuntimeError: If the provided TPU is not healthy. + """ + if not self._shouldResolve(): + return server_lib.ClusterSpec({}) + + full_name = 'projects/%s/locations/%s/nodes/%s' % ( + self._project, self._zone, compat.as_text(self._tpu)) + request = self._service.projects().locations().nodes().get(name=full_name) + response = request.execute() + + if 'health' in response and response['health'] != 'HEALTHY': + raise RuntimeError('TPU "%s" is unhealthy: "%s"' % (self._tpu, + response['health'])) + + if 'networkEndpoints' in response: + worker_list = [ + '%s:%s' % (endpoint['ipAddress'], endpoint['port']) + for endpoint in response['networkEndpoints'] + ] + else: + # Fall back to the deprecated response format + instance_url = '%s:%s' % (response['ipAddress'], response['port']) + worker_list = [instance_url] + + cluster_spec = {self._job_name: worker_list} + + if self._coordinator_address: + cluster_spec[self._coordinator_name] = [self._coordinator_address] + + return server_lib.ClusterSpec(cluster_spec) + + def _start_local_server(self): + address = self._requestComputeMetadata('instance/network-interfaces/0/ip') + self._server = server_lib.Server( + { + 'local': ['0.0.0.0:0'] + }, protocol='grpc', config=None, start=True) + # self._server.target is of the form: grpc://ipaddress:port + target = compat.as_bytes(self._server.target) + splits = target.split(compat.as_bytes(':')) + assert len(splits) == 3, self._server.target + assert splits[0] == compat.as_bytes('grpc'), self._server.target + self._coordinator_port = compat.as_text(splits[2]) + self._coordinator_address = '%s:%s' % ( + address, compat.as_text(self._coordinator_port)) + + def __deepcopy__(self, memo): + # TODO(b/73668574): Remove this once RunConfig avoids performing deepcopy. + return self diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 4fd34629cf74f90869c77b8cb098d3c585a49404..6b4a15515262b35e3cf8d7d2943e06d86b870ca9 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -21,7 +21,7 @@ from __future__ import print_function from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver from tensorflow.python.platform import test from tensorflow.python.training import server_lib - +from tensorflow.python.util import compat mock = test.mock @@ -50,10 +50,12 @@ class MockNodeClass(object): def mock_request_compute_metadata(cls, *args, **kwargs): del cls, kwargs # Unused. - if args[0] == '/project/project-id': + if args[0] == 'project/project-id': return 'test-project' - elif args[0] == '/instance/zone': + elif args[0] == 'instance/zone': return 'projects/test-project/locations/us-central1-c' + elif args[0] == 'instance/network-interfaces/0/ip': + return '10.128.1.2' return '' @@ -113,17 +115,26 @@ class TPUClusterResolverTest(test.TestCase): tpu_cluster_resolver = TPUClusterResolver( project=None, zone=None, - tpu_names=['test-tpu-1'], + tpu=['test-tpu-1'], credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } } - """ - self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + job { + name: 'coordinator' + tasks { key: 0 value: '10.128.1.2:%s' } + } + job { + name: 'worker' + tasks { key: 0 value: '10.1.2.3:8470' } + } + """ % tpu_cluster_resolver._coordinator_port + self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) - def testSimpleSuccessfulRetrieval(self): + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testRetrieveProjectAndZoneFromMetadataNoCoordinator(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', @@ -133,116 +144,217 @@ class TPUClusterResolverTest(test.TestCase): } tpu_cluster_resolver = TPUClusterResolver( - project='test-project', - zone='us-central1-c', - tpu_names=['test-tpu-1'], + project=None, + zone=None, + tpu=['test-tpu-1'], + coordinator_name=None, credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } } + job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - def testMultipleSuccessfulRetrieval(self): + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', 'port': '8470', 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - 'health': 'HEALTHY' } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1'], + tpu=['test-tpu-1'], + coordinator_address='10.128.1.5:10203', credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { name: 'tpu_worker' tasks { key: 0 value: '10.4.5.6:8470' } - tasks { key: 1 value: '10.1.2.3:8470' } } + job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } } + job { name: 'worker' tasks { key: 0 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) - def testHealthyTpuNodeRetrieval(self): + def testNewNetworkEndpointFormat(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-3': { - 'ipAddress': '10.7.8.9', - 'port': '8470', - 'health': 'UNHEALTHY' + 'health': 'HEALTHY', + 'networkEndpoints': [{ + 'ipAddress': '10.2.3.4', + 'port': 8470, + }] } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1', 'test-tpu-3'], + tpu='test-tpu-1', + coordinator_address='10.128.1.5:10203', credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) actual_cluster_spec = tpu_cluster_resolver.cluster_spec() expected_proto = """ - job { - name: 'tpu_worker' - tasks { - key: 0 - value: '10.1.2.3:8470' - } - } + job { name: 'coordinator' tasks { key: 0 value: '10.128.1.5:10203' } } + job { name: 'worker' tasks { key: 0 value: '10.2.3.4:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + self.assertEqual('grpc://10.2.3.4:8470', tpu_cluster_resolver.master()) - def testGetMasterMultipleEntries(self): + @mock.patch.object(TPUClusterResolver, '_requestComputeMetadata', + mock_request_compute_metadata) + def testPodResolution(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { - 'ipAddress': '10.1.2.3', - 'port': '8470', - 'health': 'HEALTHY' - }, - 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { - 'ipAddress': '10.4.5.6', - 'port': '8470', - 'health': 'HEALTHY' + 'health': + 'HEALTHY', + 'networkEndpoints': [ + { + 'ipAddress': '10.2.3.4', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.5', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.6', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.7', + 'port': 8470, + }, + ] + } + } + + tpu_cluster_resolver = TPUClusterResolver( + tpu='test-tpu-1', + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'coordinator', + tasks { key: 0 value: '10.128.1.2:%s'} + } + job { + name: 'worker' + tasks { key: 0 value: '10.2.3.4:8470' } + tasks { key: 1 value: '10.2.3.5:8470' } + tasks { key: 2 value: '10.2.3.6:8470' } + tasks { key: 3 value: '10.2.3.7:8470' } + } + """ % tpu_cluster_resolver._coordinator_port + self._verifyClusterSpecEquality(actual_cluster_spec, str(expected_proto)) + + def testPodResolutionNoCoordinator(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'health': + 'HEALTHY', + 'networkEndpoints': [ + { + 'ipAddress': '10.2.3.4', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.5', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.6', + 'port': 8470, + }, + { + 'ipAddress': '10.2.3.7', + 'port': 8470, + }, + ] } } tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=['test-tpu-2', 'test-tpu-1'], + tpu='test-tpu-1', + coordinator_name=None, credentials=None, service=self.mock_service_client(tpu_map=tpu_map)) - self.assertEqual('grpc://10.4.5.6:8470', tpu_cluster_resolver.get_master()) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'worker' + tasks { key: 0 value: '10.2.3.4:8470' } + tasks { key: 1 value: '10.2.3.5:8470' } + tasks { key: 2 value: '10.2.3.6:8470' } + tasks { key: 3 value: '10.2.3.7:8470' } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) def testGetMasterNoEntries(self): tpu_map = {} + with self.assertRaises(ValueError): + TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu=[], + coordinator_name=None, + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + # TODO(saeta): Convert to parameterized test when included in OSS TF. + def verifyShouldResolve(self, tpu, should_resolve): tpu_cluster_resolver = TPUClusterResolver( project='test-project', zone='us-central1-c', - tpu_names=[], + tpu=tpu, + coordinator_name=None, credentials=None, - service=self.mock_service_client(tpu_map=tpu_map)) - with self.assertRaises(ValueError): - tpu_cluster_resolver.get_master() + service=self.mock_service_client(tpu_map={})) + self.assertEqual(should_resolve, tpu_cluster_resolver._shouldResolve(), + "TPU: '%s'" % tpu) + + def testShouldResolveNoName(self): + self.verifyShouldResolve('', False) + + def testShouldResolveLocal(self): + self.verifyShouldResolve('local', False) + + def testShouldResolveGrpc(self): + self.verifyShouldResolve('grpc://10.1.2.3:8470', False) + + def testShouldResolveBns(self): + self.verifyShouldResolve('/bns/foo/bar', False) + + def testShouldResolveName(self): + self.verifyShouldResolve('mytpu', True) + + def testShouldResolveList(self): + self.verifyShouldResolve(['myothertpu'], True) + + def testShouldResolveGrpcPrefix(self): + self.verifyShouldResolve('grpctpu', True) + + def testNoCallComputeMetadata(self): + tpu_cluster_resolver = TPUClusterResolver(tpu='/bns/foo/bar') + self.assertEqual(compat.as_bytes('/bns/foo/bar'), + tpu_cluster_resolver.master()) + self.assertEqual( + server_lib.ClusterSpec({}), tpu_cluster_resolver.cluster_spec()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index 16317f538f3890661f1b59ea39fe67dcf04d0d0a..23b31ae1dcc83d8a7152354ac147de9ada320429 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -341,7 +341,8 @@ if (tensorflow_ENABLE_GPU) if(NOT CUDNN_HOME) set(CUDNN_HOME ${CUDA_TOOLKIT_TARGET_DIR}) endif(NOT CUDNN_HOME) - include_directories(${CUDNN_HOME}) + set(CUDNN_INCLUDE "${CUDNN_HOME}/include") + set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${CUDNN_HOME}/lib/x64/cudnn.lib) else (WIN32) @@ -369,10 +370,10 @@ if (tensorflow_ENABLE_GPU) message("culibos-static: ${culibos_STATIC_LIBRARY}") endif (NOT culibos_STATIC_LIBRARY) - include_directories(${CUDNN_INCLUDE}) set(CUDA_LIBRARIES ${CUDA_LIBRARIES} ${CUDA_CUDA_LIBRARY} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_CUFFT_LIBRARIES} ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) endif (WIN32) + include_directories(${CUDNN_INCLUDE}) # Remove "." from CUDA version variable. string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) @@ -388,31 +389,22 @@ if (tensorflow_ENABLE_GPU) "#endif // CUDA_CUDA_CONFIG_H_\n" ) - if (WIN32) - # tf assumes in various places header files to be in cuda/include. On windows the cuda sdk - # installs them under cuda/version/include and to avoid that we need to change tf we copy a - # few files to cuda/include - FILE(COPY - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDNN_HOME}/include/cudnn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h - DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include - ) - else(WIN32) - # Linux has slightly differnt install paths than Windows - FILE(COPY - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h ${CUDNN_INCLUDE}/cudnn.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h - ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h - DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include - ) - endif(WIN32) + # tf assumes in various places header files to be in cuda/include. On windows the cuda sdk + # installs them under cuda/version/include and to avoid that we need to change tf we copy a + # few files to cuda/include + FILE(COPY + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuComplex.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cublas_v2.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h + ${CUDNN_INCLUDE}/cudnn.h + DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include + ) include_directories(${tensorflow_source_dir}/third_party/gpus) # add cuda libraries to tensorflow_EXTERNAL_LIBRARIES diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake index 836889895567f679d9960e29ece1600d1a7a58eb..98a8c7e736e5c8c407b90e8eac440cdc7ab21579 100644 --- a/tensorflow/contrib/cmake/external/cub.cmake +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip) -set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31) +set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip) +set(cub_HASH SHA256=6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3) set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive) diff --git a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c index 968ab13a0c43793341431248713f81010c87f148..9e355da33a7258119b6086216f5487d7ea94716c 100644 --- a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c +++ b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c @@ -1,3 +1,18 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + // This is a program to test if compiler is compatible with CUDA. #define __CUDACC__ #include "crt/host_config.h" diff --git a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc index 968ab13a0c43793341431248713f81010c87f148..beb574061bea8d04af8386223749677ae36a5d9b 100644 --- a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc +++ b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc @@ -1,3 +1,18 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================*/ + // This is a program to test if compiler is compatible with CUDA. #define __CUDACC__ #include "crt/host_config.h" diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 96ac60d095dbc84470ff1be92f4bf52bb420fc52..a54cbff33b66d63d7229fa2f50b8a4ca962111ed 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -63,6 +63,12 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" "${tensorflow_source_dir}/tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.cc" ) +file(GLOB_RECURSE tf_core_cpu_whitelisted_srcs + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id.h" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id.cc" + "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc" +) +list(REMOVE_ITEM tf_core_cpu_exclude_srcs ${tf_core_cpu_whitelisted_srcs}) list(REMOVE_ITEM tf_core_cpu_srcs ${tf_core_cpu_exclude_srcs}) if (tensorflow_ENABLE_GPU) @@ -79,6 +85,7 @@ if (tensorflow_ENABLE_GPU) "${tensorflow_source_dir}/tensorflow/core/*test*.cc" ) list(REMOVE_ITEM tf_core_gpu_srcs ${tf_core_gpu_exclude_srcs}) + list(REMOVE_ITEM tf_core_gpu_srcs ${tf_core_cpu_whitelisted_srcs}) list(APPEND tf_core_cpu_srcs ${tf_core_gpu_srcs}) endif() diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index f219d5eb577afa9edaadca09aef9869c81d2bd87..998f99ecc19f88921dce14fde892912fb699ad08 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -71,6 +71,8 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/unique_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 799bb8e58e3c200a141fe33ae6b4710a61f7bd78..1c4ebd7f0c1113bcd0857fb0858df2248499f920 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -276,8 +276,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" - # Broken tensorboard test due to cmake issues. - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py" # Deadlocks "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 62708636c6181ca63cddf2b2e7c84d3da740282a..1233c8f251c404c57d9e2b38993e7a386b1e6ceb 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -105,8 +105,8 @@ def crf_sequence_score(inputs, tag_indices, sequence_lengths, return utils.smart_cond( pred=math_ops.equal(inputs.shape[1].value or array_ops.shape(inputs)[1], 1), - fn1=_single_seq_fn, - fn2=_multi_seq_fn) + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) def crf_log_norm(inputs, sequence_lengths, transition_params): @@ -166,8 +166,8 @@ def crf_log_likelihood(inputs, sequence_lengths: A [batch_size] vector of true sequence lengths. transition_params: A [num_tags, num_tags] transition matrix, if available. Returns: - log_likelihood: A scalar containing the log-likelihood of the given sequence - of tag indices. + log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of + each example, given the sequence of tag indices. transition_params: A [num_tags, num_tags] transition matrix. This is either provided by the caller or created in this function. """ @@ -182,7 +182,7 @@ def crf_log_likelihood(inputs, transition_params) log_norm = crf_log_norm(inputs, sequence_lengths, transition_params) - # Normalize the scores to get the log-likelihood. + # Normalize the scores to get the log-likelihood per example. log_likelihood = sequence_scores - log_norm return log_likelihood, transition_params @@ -511,7 +511,7 @@ def crf_decode(potentials, transition_params, sequence_length): return decode_tags, best_score return utils.smart_cond( - pred=math_ops.equal( - potentials.shape[1].value or array_ops.shape(potentials)[1], 1), - fn1=_single_seq_fn, - fn2=_multi_seq_fn) + pred=math_ops.equal(potentials.shape[1].value or + array_ops.shape(potentials)[1], 1), + true_fn=_single_seq_fn, + false_fn=_multi_seq_fn) diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 56471911c5c0d1c1825955c67997b5bbc0786463..9bd6a42da2d93263e84a759cffdc5a9e8f9742fd 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -28,11 +28,33 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "threadpool_dataset_op", + srcs = ["threadpool_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + +cc_library( + name = "unique_dataset_op", + srcs = ["unique_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + cc_library( name = "dataset_kernels", deps = [ ":ignore_errors_dataset_op", ":prefetching_kernels", + ":threadpool_dataset_op", + ":unique_dataset_op", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4b3edde85fc755f1c7694a555b867317e81f149d --- /dev/null +++ b/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc @@ -0,0 +1,197 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/threadpool.h" + +namespace tensorflow { +namespace { + +class ThreadPoolResource : public ResourceBase { + public: + ThreadPoolResource(Env* env, const ThreadOptions& thread_options, + const string& name, int num_threads, bool low_latency_hint) + : thread_pool_(env, thread_options, name, num_threads, low_latency_hint) { + } + + // Schedules fn() for execution in the pool of threads. + void Schedule(std::function fn) { + thread_pool_.Schedule(std::move(fn)); + } + + string DebugString() override { return "ThreadPoolResource"; } + + private: + thread::ThreadPool thread_pool_; +}; + +// Creates a handle to a ThreadPool resource. Note that we don't use +// ResourceOpKernel here because the ThreadPoolResource constructor requires +// access to `OpKernelContext::env()`, which isn't provided by +// `ResourceOpKernel::CreateResource()`. +class ThreadPoolHandleOp : public OpKernel { + public: + explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); + OP_REQUIRES( + ctx, num_threads_ > 0, + errors::InvalidArgument("`num_threads` must be greater than zero.")); + } + + // The resource is deleted from the resource manager only when it is private + // to kernel. Ideally the resource should be deleted when it is no longer held + // by anyone, but it would break backward compatibility. + ~ThreadPoolHandleOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete(cinfo_.container(), cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + ThreadPoolResource* resource; + OP_REQUIRES_OK(ctx, mgr->LookupOrCreate( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](ThreadPoolResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new ThreadPoolResource( + ctx->env(), {}, display_name_, + num_threads_, + false /* low_latency_hint */); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; + string display_name_; + int num_threads_; +}; + +class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { + public: + explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + ThreadPoolResource* threadpool_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), + &threadpool_resource)); + core::ScopedUnref unref_iterator(threadpool_resource); + + *output = new Dataset(ctx, input, threadpool_resource); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + ThreadPoolResource* threadpool) + : GraphDatasetBase(ctx), input_(input), threadpool_(threadpool) { + input_->Ref(); + threadpool_->Ref(); + } + + ~Dataset() override { + input_->Unref(); + threadpool_->Unref(); + } + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::ThreadPool")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + const std::vector& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() override { return "ThreadPoolDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented( + "Cannot currently serialize the thread pool for a " + "ThreadPoolDataset."); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params), + input_impl_(params.dataset->input_->MakeIterator(params.prefix)) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + ThreadPoolResource* pool = dataset()->threadpool_; + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = [pool](std::function c) { + pool->Schedule(std::move(c)); + }; + params.stats_aggregator_getter = [ctx]() { + return ctx->stats_aggregator(); + }; + params.lib = ctx->lib(); + params.function_library = ctx->function_library(); + params.allocator_getter = [ctx](AllocatorAttributes attrs) { + return ctx->allocator(attrs); + }; + IteratorContext threadpool_ctx(params); + return input_impl_->GetNext(&threadpool_ctx, out_tensors, + end_of_sequence); + } + + private: + std::unique_ptr input_impl_; + }; + + const DatasetBase* const input_; + ThreadPoolResource* const threadpool_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU), + ThreadPoolHandleOp); +REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU), + ThreadPoolDatasetOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/unique_dataset_op.cc b/tensorflow/contrib/data/kernels/unique_dataset_op.cc similarity index 99% rename from tensorflow/core/kernels/data/unique_dataset_op.cc rename to tensorflow/contrib/data/kernels/unique_dataset_op.cc index 7726ee0edf71b34cb65fe5fceb2b60dd30bb58e2..69fbb0fcdcce87951d2c9b84210fda378081b103 100644 --- a/tensorflow/core/kernels/data/unique_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/unique_dataset_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/hash/hash.h" namespace tensorflow { diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc index 289ffa1d9c29092cdf434e86ed5553ff9644d43e..a4c1212da11a2410461a120ed5f7116e80e4b903 100644 --- a/tensorflow/contrib/data/ops/dataset_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -27,6 +27,16 @@ REGISTER_OP("IgnoreErrorsDataset") Creates a dataset that contains the elements of `input_dataset` ignoring errors. )doc"); +REGISTER_OP("UniqueDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that contains the unique elements of `input_dataset`. +)doc"); + REGISTER_OP("FunctionBufferingResource") .Input("string_arg: string") .Input("target_device: string") @@ -65,4 +75,33 @@ output: A list of return values. output_types: The type list for the return values. )doc"); +REGISTER_OP("ThreadPoolDataset") + .Input("input_dataset: variant") + .Input("thread_pool: resource") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that uses a custom thread pool to compute `input_dataset`. + +handle: A resource produced by the ThreadPoolHandle op. +)doc"); + +REGISTER_OP("ThreadPoolHandle") + .Output("handle: resource") + .SetShapeFn(shape_inference::ScalarShape) + .Attr("num_threads: int") + .Attr("display_name: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''") + .Doc(R"doc( +Creates a custom thread pool with the given number of threads. + +handle: A resource that can be consumed by one or more ThreadPoolDataset ops. +num_threads: The number of threads in the thread pool. +display_name: A human-readable name for the threads that may be visible in + some visualizations. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index e51d57cc896dc32d8e11912cd89f34a04a858c78..82cd276ce8073b1e66bbc620fa845733aaaca4d4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -419,6 +419,20 @@ py_test( ], ) +py_test( + name = "threadpool_dataset_ops_test", + size = "small", + srcs = ["threadpool_dataset_ops_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + ], +) + py_test( name = "unique_dataset_op_test", size = "small", diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index efd864f866611bfd3bac1edcf98d84be852410fd..e26cef8ec522c7e69a0c19b2b30a969bbfc0ad78 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os + import sqlite3 from tensorflow.contrib.data.python.ops import readers diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9167cb3379bba5cb1ba76a96549395c45dca9e35 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py @@ -0,0 +1,77 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline statistics gathering ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +import numpy as np + +from tensorflow.contrib.data.python.ops import threadpool +from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import script_ops +from tensorflow.python.platform import test + + +class OverrideThreadpoolDatasetTest(test.TestCase): + + def testNumThreads(self): + + def get_thread_id(_): + # Python creates a dummy thread object to represent the current + # thread when called from an "alien" thread (such as a + # `PrivateThreadPool` thread in this case). It does not include + # the TensorFlow-given display name, but it has a unique + # identifier that maps one-to-one with the underlying OS thread. + return np.array(threading.current_thread().ident).astype(np.int64) + + for num_threads in [1, 2, 4, 8, 16]: + + dataset = ( + dataset_ops.Dataset.range(1000).map( + lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), + num_parallel_calls=32).apply(unique.unique())) + + dataset = threadpool.override_threadpool( + dataset, + threadpool.PrivateThreadPool( + num_threads, display_name="private_thread_pool_%d" % num_threads)) + + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + sess.run(iterator.initializer) + thread_ids = [] + try: + while True: + thread_ids.append(sess.run(next_element)) + except errors.OutOfRangeError: + pass + self.assertEqual(len(thread_ids), len(set(thread_ids))) + self.assertGreater(len(thread_ids), 0) + # NOTE(mrry): We don't control the thread pool scheduling, and + # so cannot guarantee that all of the threads in the pool will + # perform work. + self.assertLessEqual(len(thread_ids), num_threads) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index b488357f226d0922bba3799cc1f4b5c75e2e8328..789cb9c99a6bba06a1e3bd3371d1378065f49f46 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -105,6 +105,7 @@ py_library( "resampling.py", "scan_ops.py", "stats_ops.py", + "threadpool.py", "unique.py", ], srcs_version = "PY2AND3", @@ -120,6 +121,7 @@ py_library( "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python:util", diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py deleted file mode 100644 index ff15c4451ad987bcd77dbdd022a1c070056c47e1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ /dev/null @@ -1,691 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Python wrappers for Datasets and Iterators.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import enumerate_ops -from tensorflow.contrib.data.python.ops import error_ops -from tensorflow.contrib.data.python.ops import grouping -from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.util import nest -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import gen_io_ops -from tensorflow.python.util import deprecation - - -class Dataset(dataset_ops.Dataset): - """Represents a potentially large set of elements. - - A `Dataset` can be used to represent an input pipeline as a - collection of elements (nested structures of tensors) and a "logical - plan" of transformations that act on those elements. - """ - - def __init__(self, dataset): - super(Dataset, self).__init__() - self._dataset = dataset - - @deprecation.deprecated(None, "Use `ds._as_variant_tensor()`.") - def make_dataset_resource(self): - return self._as_variant_tensor() - - def _as_variant_tensor(self): - return self._dataset._as_variant_tensor() # pylint: disable=protected-access - - @property - def output_classes(self): - return self._dataset.output_classes - - @property - def output_shapes(self): - return self._dataset.output_shapes - - @property - def output_types(self): - return self._dataset.output_types - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensors()`.") - def from_tensors(tensors): - """Creates a `Dataset` with a single element, comprising the given tensors. - - Args: - tensors: A nested structure of tensors. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.TensorDataset(tensors)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") - def from_tensor_slices(tensors): - """Creates a `Dataset` whose elements are slices of the given tensors. - - Args: - tensors: A nested structure of tensors, each having the same size in the - 0th dimension. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.TensorSliceDataset(tensors)) - - @staticmethod - @deprecation.deprecated(None, - "Use `tf.data.Dataset.from_sparse_tensor_slices()`.") - def from_sparse_tensor_slices(sparse_tensor): - """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. - - Args: - sparse_tensor: A `tf.SparseTensor`. - - Returns: - A `Dataset` of rank-(N-1) sparse tensors. - """ - return Dataset(dataset_ops.SparseTensorSliceDataset(sparse_tensor)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.from_generator()`.") - def from_generator(generator, output_types, output_shapes=None): - """Creates a `Dataset` whose elements are generated by `generator`. - - The `generator` argument must be a callable object that returns - an object that support the `iter()` protocol (e.g. a generator function). - The elements generated by `generator` must be compatible with the given - `output_types` and (optional) `output_shapes` arguments. - - For example: - - ```python - import itertools - - def gen(): - for i in itertools.count(1): - yield (i, [1] * i) - - ds = Dataset.from_generator( - gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None]))) - value = ds.make_one_shot_iterator().get_next() - - sess.run(value) # (1, array([1])) - sess.run(value) # (2, array([1, 1])) - ``` - - Args: - generator: A callable object that takes no arguments and returns an - object that supports the `iter()` protocol. - output_types: A nested structure of `tf.DType` objects corresponding to - each component of an element yielded by `generator`. - output_shapes: (Optional.) A nested structure of `tf.TensorShape` - objects corresponding to each component of an element yielded by - `generator`. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.Dataset.from_generator( - generator, output_types, output_shapes)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.range()`.") - def range(*args): - """Creates a `Dataset` of a step-separated range of values. - - For example: - - ```python - Dataset.range(5) == [0, 1, 2, 3, 4] - Dataset.range(2, 5) == [2, 3, 4] - Dataset.range(1, 5, 2) == [1, 3] - Dataset.range(1, 5, -2) == [] - Dataset.range(5, 1) == [] - Dataset.range(5, 1, -2) == [5, 3] - ``` - - Args: - *args: follow same semantics as python's xrange. - len(args) == 1 -> start = 0, stop = args[0], step = 1 - len(args) == 2 -> start = args[0], stop = args[1], step = 1 - len(args) == 3 -> start = args[0], stop = args[1, stop = args[2] - - Returns: - A `RangeDataset`. - - Raises: - ValueError: if len(args) == 0. - """ - return Dataset(dataset_ops.RangeDataset(*args)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.zip()`.") - def zip(datasets): - """Creates a `Dataset` by zipping together the given datasets. - - This method has similar semantics to the built-in `zip()` function - in Python, with the main difference being that the `datasets` - argument can be an arbitrary nested structure of `Dataset` objects. - For example: - - ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { 1, 2, 3 } - b = { 4, 5, 6 } - c = { (7, 8), (9, 10), (11, 12) } - d = { 13, 14 } - - # The nested structure of the `datasets` argument determines the - # structure of elements in the resulting dataset. - Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) } - Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) } - - # The `datasets` argument may contain an arbitrary number of - # datasets. - Dataset.zip((a, b, c)) == { (1, 4, (7, 8)), - (2, 5, (9, 10)), - (3, 6, (11, 12)) } - - # The number of elements in the resulting dataset is the same as - # the size of the smallest dataset in `datasets`. - Dataset.zip((a, d)) == { (1, 13), (2, 14) } - ``` - - Args: - datasets: A nested structure of datasets. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.ZipDataset(datasets)) - - def concatenate(self, dataset): - """Creates a `Dataset` by concatenating given dataset with this dataset. - - ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { 1, 2, 3 } - b = { 4, 5, 6, 7 } - - # Input dataset and dataset to be concatenated should have same - # nested structures and output types. - # c = { (8, 9), (10, 11), (12, 13) } - # d = { 14.0, 15.0, 16.0 } - # a.concatenate(c) and a.concatenate(d) would result in error. - - a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 } - ``` - - Args: - dataset: `Dataset` to be concatenated. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.ConcatenateDataset(self._dataset, dataset)) - - def prefetch(self, buffer_size): - """Creates a `Dataset` that prefetches elements from this dataset. - - Args: - buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the - maximum number elements that will be buffered when prefetching. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.PrefetchDataset(self._dataset, buffer_size)) - - @staticmethod - @deprecation.deprecated(None, "Use `tf.data.Dataset.list_files()`.") - def list_files(file_pattern): - """A dataset of all files matching a pattern. - - Example: - If we had the following files on our filesystem: - - /path/to/dir/a.txt - - /path/to/dir/b.py - - /path/to/dir/c.py - If we pass "/path/to/dir/*.py" as the directory, the dataset would - produce: - - /path/to/dir/b.py - - /path/to/dir/c.py - - Args: - file_pattern: A string or scalar string `tf.Tensor`, representing - the filename pattern that will be matched. - - Returns: - A `Dataset` of strings corresponding to file names. - """ - return Dataset.from_tensor_slices(gen_io_ops.matching_files(file_pattern)) - - def repeat(self, count=None): - """Repeats this dataset `count` times. - - Args: - count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the - number of times the elements of this dataset should be repeated. The - default behavior (if `count` is `None` or `-1`) is for the elements to - be repeated indefinitely. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.RepeatDataset(self._dataset, count)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.enumerate_dataset())`.") - def enumerate(self, start=0): - """Deprecated: Use `Dataset.apply(tf.contrib.data.enumerate_dataset(..)`.""" - - return self.apply(enumerate_ops.enumerate_dataset(start)) - - def shuffle(self, buffer_size, seed=None): - """Randomly shuffles the elements of this dataset. - - Args: - buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the - number of elements from this dataset from which the new - dataset will sample. - seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the - random seed that will be used to create the distribution. See - @{tf.set_random_seed} for behavior. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.ShuffleDataset(self._dataset, buffer_size, seed)) - - def cache(self, filename=""): - """Caches the elements in this dataset. - - Args: - filename: A `tf.string` scalar `tf.Tensor`, representing the name of a - directory on the filesystem to use for caching tensors in this Dataset. - If a filename is not provided, the dataset will be cached in memory. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.CacheDataset(self._dataset, filename)) - - def take(self, count): - """Creates a `Dataset` with at most `count` elements from this dataset. - - Args: - count: A `tf.int64` scalar `tf.Tensor`, representing the number of - elements of this dataset that should be taken to form the new dataset. - If `count` is -1, or if `count` is greater than the size of this - dataset, the new dataset will contain all elements of this dataset. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.TakeDataset(self._dataset, count)) - - def skip(self, count): - """Creates a `Dataset` that skips `count` elements from this dataset. - - Args: - count: A `tf.int64` scalar `tf.Tensor`, representing the number - of elements of this dataset that should be skipped to form the - new dataset. If `count` is greater than the size of this - dataset, the new dataset will contain no elements. If `count` - is -1, skips the entire dataset. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.SkipDataset(self._dataset, count)) - - def shard(self, num_shards, index): - """Creates a `Dataset` that includes only 1/`num_shards` of this dataset. - - This dataset operator is very useful when running distributed training, as - it allows each worker to read a unique subset. - - When reading a single input file, you can skip elements as follows: - - ```python - d = tf.data.TFRecordDataset(FLAGS.input_file) - d = d.shard(FLAGS.num_workers, FLAGS.worker_index) - d = d.repeat(FLAGS.num_epochs) - d = d.shuffle(FLAGS.shuffle_buffer_size) - d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) - ``` - - Important caveats: - - - Be sure to shard before you use any randomizing operator (such as - shuffle). - - Generally it is best if the shard operator is used early in the dataset - pipeline. For example, when reading from a set of TFRecord files, shard - before converting the dataset to input samples. This avoids reading every - file on every worker. The following is an example of an efficient - sharding strategy within a complete pipeline: - - ```python - d = tf.data.Dataset.list_files(FLAGS.pattern) - d = d.shard(FLAGS.num_workers, FLAGS.worker_index) - d = d.repeat(FLAGS.num_epochs) - d = d.shuffle(FLAGS.shuffle_buffer_size) - d = d.interleave(tf.data.TFRecordDataset, - cycle_length=FLAGS.num_readers, block_length=1) - d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) - ``` - - Args: - num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of - shards operating in parallel. - index: A `tf.int64` scalar `tf.Tensor`, representing the worker index. - - Returns: - A `Dataset`. - - Raises: - ValueError: if `num_shards` or `index` are illegal values. Note: error - checking is done on a best-effort basis, and aren't guaranteed to be - caught upon dataset creation. (e.g. providing in a placeholder tensor - bypasses the early checking, and will instead result in an error during - a session.run call.) - """ - return Dataset(self._dataset.shard(num_shards, index)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.ignore_errors())`.") - def ignore_errors(self): - """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors())`.""" - - return self.apply(error_ops.ignore_errors()) - - def batch(self, batch_size): - """Combines consecutive elements of this dataset into batches. - - Args: - batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements of this dataset to combine in a single batch. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.BatchDataset(self._dataset, batch_size)) - - def padded_batch(self, batch_size, padded_shapes, padding_values=None): - """Combines consecutive elements of this dataset into padded batches. - - Like `Dataset.dense_to_sparse_batch()`, this method combines - multiple consecutive elements of this dataset, which might have - different shapes, into a single element. The tensors in the - resulting element have an additional outer dimension, and are - padded to the respective shape in `padded_shapes`. - - Args: - batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements of this dataset to combine in a single batch. - padded_shapes: A nested structure of `tf.TensorShape` or - `tf.int64` vector tensor-like objects representing the shape - to which the respective component of each input element should - be padded prior to batching. Any unknown dimensions - (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a - tensor-like object) will be padded to the maximum size of that - dimension in each batch. - padding_values: (Optional.) A nested structure of scalar-shaped - `tf.Tensor`, representing the padding values to use for the - respective components. Defaults are `0` for numeric types and - the empty string for string types. - - Returns: - A `Dataset`. - """ - return Dataset( - dataset_ops.PaddedBatchDataset(self._dataset, batch_size, padded_shapes, - padding_values)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.dense_to_sparse_batch())`.") - def dense_to_sparse_batch(self, batch_size, row_shape): - """Use: `Dataset.apply(tf.contrib.data.dense_to_sparse_batch(...))`.""" - - return self.apply(batching.dense_to_sparse_batch(batch_size, row_shape)) - - @deprecation.deprecated( - None, "Use `ds.apply(tf.contrib.data.group_by_window())`.") - def group_by_window(self, key_func, reduce_func, window_size): - """Deprecated: Use `Dataset.apply(tf.contrib.data.group_by_window(...))`.""" - - return self.apply( - grouping.group_by_window(key_func, reduce_func, window_size)) - - @deprecation.deprecated_args( - None, - "Replace `num_threads=T` with `num_parallel_calls=T`. Replace " - "`output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.", - "num_threads", "output_buffer_size") - def map(self, - map_func, - num_threads=None, - output_buffer_size=None, - num_parallel_calls=None): - """Maps `map_func` across this dataset. - - Args: - map_func: A function mapping a nested structure of tensors (having - shapes and types defined by `self.output_shapes` and - `self.output_types`) to another nested structure of tensors. - num_threads: (Optional.) Deprecated, use `num_parallel_calls` instead. - output_buffer_size: (Optional.) A `tf.int64` scalar `tf.Tensor`, - representing the maximum number of processed elements that will be - buffered. - num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`, - representing the number elements to process in parallel. If not - specified, elements will be processed sequentially. - - Returns: - A `Dataset`. - """ - if num_threads is None and num_parallel_calls is None: - ret = Dataset(dataset_ops.MapDataset(self._dataset, map_func)) - else: - if num_threads is None: - ret = Dataset( - dataset_ops.ParallelMapDataset(self._dataset, map_func, - num_parallel_calls)) - else: - ret = Dataset( - dataset_ops.ParallelMapDataset(self._dataset, map_func, - num_threads)) - if output_buffer_size is not None: - ret = ret.prefetch(output_buffer_size) - return ret - - def flat_map(self, map_func): - """Maps `map_func` across this dataset and flattens the result. - - Args: - map_func: A function mapping a nested structure of tensors (having shapes - and types defined by `self.output_shapes` and `self.output_types`) to a - `Dataset`. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.FlatMapDataset(self._dataset, map_func)) - - def interleave(self, map_func, cycle_length, block_length=1): - """Maps `map_func` across this dataset, and interleaves the results. - - For example, you can use `Dataset.interleave()` to process many input files - concurrently: - - ```python - # Preprocess 4 files concurrently, and interleave blocks of 16 records from - # each file. - filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...] - dataset = (Dataset.from_tensor_slices(filenames) - .interleave(lambda x: - TextLineDataset(x).map(parse_fn, num_parallel_calls=1), - cycle_length=4, block_length=16)) - ``` - - The `cycle_length` and `block_length` arguments control the order in which - elements are produced. `cycle_length` controls the number of input elements - that are processed concurrently. If you set `cycle_length` to 1, this - transformation will handle one input element at a time, and will produce - identical results = to @{tf.data.Dataset.flat_map}. In general, - this transformation will apply `map_func` to `cycle_length` input elements, - open iterators on the returned `Dataset` objects, and cycle through them - producing `block_length` consecutive elements from each iterator, and - consuming the next input element each time it reaches the end of an - iterator. - - For example: - - ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { 1, 2, 3, 4, 5 } - - # NOTE: New lines indicate "block" boundaries. - a.interleave(lambda x: Dataset.from_tensors(x).repeat(6), - cycle_length=2, block_length=4) == { - 1, 1, 1, 1, - 2, 2, 2, 2, - 1, 1, - 2, 2, - 3, 3, 3, 3, - 4, 4, 4, 4, - 3, 3, - 4, 4, - 5, 5, 5, 5, - 5, 5, - } - ``` - - NOTE: The order of elements yielded by this transformation is - deterministic, as long as `map_func` is a pure function. If - `map_func` contains any stateful operations, the order in which - that state is accessed is undefined. - - Args: - map_func: A function mapping a nested structure of tensors (having shapes - and types defined by `self.output_shapes` and `self.output_types`) to a - `Dataset`. - cycle_length: The number of elements from this dataset that will be - processed concurrently. - block_length: The number of consecutive elements to produce from each - input element before cycling to another input element. - - Returns: - A `Dataset`. - """ - return Dataset( - dataset_ops.InterleaveDataset(self._dataset, map_func, cycle_length, - block_length)) - - @deprecation.deprecated(None, "Use `ds.apply(tf.contrib.data.unbatch())`.") - def unbatch(self): - """Deprecated: Use `Dataset.apply(tf.contrib.data.unbatch()`.""" - - return self.apply(batching.unbatch()) - - def filter(self, predicate): - """Filters this dataset according to `predicate`. - - Args: - predicate: A function mapping a nested structure of tensors (having shapes - and types defined by `self.output_shapes` and `self.output_types`) to a - scalar `tf.bool` tensor. - - Returns: - A `Dataset`. - """ - return Dataset(dataset_ops.FilterDataset(self._dataset, predicate)) - - def apply(self, transformation_func): - """Apply a transformation function to this dataset. - - `apply` enables chaining of custom `Dataset` transformations, which are - represented as functions that take one `Dataset` argument and return a - transformed `Dataset`. - - For example: - - ``` - dataset = (dataset.map(lambda x: x ** 2) - .(group_by_window(key_func, reduce_func, window_size)) - .map(lambda x: x ** 3)) - ``` - - Args: - transformation_func: A function that takes one `Dataset` argument and - returns a `Dataset`. - - Returns: - The `Dataset` returned by applying `transformation_func` to this dataset. - """ - dataset = transformation_func(self) - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`transformation_func` must return a Dataset.") - return Dataset(dataset) - - -def get_single_element(dataset): - """Returns the single element in `dataset` as a nested structure of tensors. - - This function enables you to use a @{tf.data.Dataset} in a stateless - "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. - This can be useful when your preprocessing transformations are expressed - as a `Dataset`, and you want to use the transformation at serving time. - For example: - - ```python - input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) - - def preprocessing_fn(input_str): - # ... - return image, label - - dataset = (tf.data.Dataset.from_tensor_slices(input_batch) - .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) - .batch(BATCH_SIZE)) - - image_batch, label_batch = tf.contrib.data.get_single_element(dataset) - ``` - - Args: - dataset: A @{tf.data.Dataset} object containing a single element. - - Returns: - A nested structure of @{tf.Tensor} objects, corresponding to the single - element of `dataset`. - - Raises: - TypeError: if `dataset` is not a `tf.data.Dataset` object. - InvalidArgumentError (at runtime): if `dataset` does not contain exactly - one element. - """ - if not isinstance(dataset, dataset_ops.Dataset): - raise TypeError("`dataset` must be a `tf.data.Dataset` object.") - return nest.pack_sequence_as( - dataset.output_types, - gen_dataset_ops.dataset_to_single_element( - dataset._as_variant_tensor(), # pylint: disable=protected-access - output_types=nest.flatten(dataset.output_types), - output_shapes=nest.flatten(dataset.output_shapes))) diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py new file mode 100644 index 0000000000000000000000000000000000000000..3f85aa84cd53fcf5e21480aac96e067766ad1b65 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/threadpool.py @@ -0,0 +1,102 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental API for controlling threading in `tf.data` pipelines.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import threading + +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.eager import context +from tensorflow.python.ops import resource_variable_ops + +_uid_counter = 0 +_uid_lock = threading.Lock() + + +def _generate_shared_name(prefix): + with _uid_lock: + global _uid_counter + uid = _uid_counter + _uid_counter += 1 + return "{}{}".format(prefix, uid) + + +class PrivateThreadPool(object): + """A stateful resource that represents a private thread pool.""" + + def __init__(self, num_threads, display_name=None): + """Creates a `PrivateThreadPool` with the given number of threads.""" + if context.in_eager_mode(): + shared_name = _generate_shared_name("privatethreadpool") + self._resource = gen_dataset_ops.thread_pool_handle( + num_threads=num_threads, + display_name=display_name, + shared_name=shared_name) + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device=context.context().device_name) + else: + self._resource = gen_dataset_ops.thread_pool_handle( + num_threads=num_threads, display_name=display_name) + + +class _ThreadPoolDataset(dataset_ops.Dataset): + """A `Dataset` that acts as an identity, and sets a custom threadpool.""" + + def __init__(self, input_dataset, thread_pool): + super(_ThreadPoolDataset, self).__init__() + self._input_dataset = input_dataset + self._thread_pool = thread_pool + + def _as_variant_tensor(self): + return gen_dataset_ops.thread_pool_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._thread_pool._resource, # pylint: disable=protected-access + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + @property + def output_classes(self): + return self._input_dataset.output_classes + + +def override_threadpool(dataset, thread_pool): + """Returns a new dataset that uses the given thread pool for its operations. + + Args: + dataset: A `tf.data.Dataset` object. + thread_pool: A `PrivateThreadPool` object. + + Returns: + A dataset containing the same values as `dataset`, but which uses + `thread_pool` to compute any of its parallel operations (such as + @{tf.data.Dataset.map}). + """ + return _ThreadPoolDataset(dataset, thread_pool) diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py index 133e17d20d0fc4c8d52cef3c95c132374e927a0b..765ef3f9b6d42c9d7af3ce4916731d37d65c9260 100644 --- a/tensorflow/contrib/data/python/ops/unique.py +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -17,11 +17,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes -from tensorflow.python.ops import gen_dataset_ops def unique(): diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 7f510c42215f48a9e795eb81bd9f66b0a2108335..ed79ef70f829f9b72fa67026a5f7a0928130e95b 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -251,6 +251,21 @@ cuda_py_test( ], ) +cuda_py_test( + name = "kumaraswamy_test", + srcs = ["python/kernel_tests/kumaraswamy_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "moving_stats_test", size = "small", @@ -403,7 +418,7 @@ cuda_py_test( cuda_py_test( name = "poisson_lognormal_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/poisson_lognormal_test.py"], additional_deps = [ ":distributions_py", @@ -915,6 +930,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "kumaraswamy_bijector_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/kumaraswamy_bijector_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "masked_autoregressive_test", size = "small", @@ -984,7 +1018,7 @@ cuda_py_test( cuda_py_test( name = "reshape_test", - size = "small", + size = "medium", srcs = ["python/kernel_tests/bijectors/reshape_test.py"], additional_deps = [ ":bijectors_py", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 837af20ade965f7cd80064d93a2aaa05f5b68f32..61c411271d0bb8d7b4cc3b14992b82ec1e5674ed 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -40,6 +40,7 @@ from tensorflow.contrib.distributions.python.ops.geometric import * from tensorflow.contrib.distributions.python.ops.half_normal import * from tensorflow.contrib.distributions.python.ops.independent import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * +from tensorflow.contrib.distributions.python.ops.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.logistic import * from tensorflow.contrib.distributions.python.ops.mixture import * from tensorflow.contrib.distributions.python.ops.mixture_same_family import * @@ -114,6 +115,7 @@ _allowed_symbols = [ 'Independent', 'InverseGamma', 'InverseGammaWithSoftplusConcentrationRate', + 'Kumaraswamy', 'Laplace', 'LaplaceWithSoftplusScale', 'Logistic', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ad11d9f2484c4b08c67c5f82aec1320475d1d983 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/kumaraswamy_bijector_test.py @@ -0,0 +1,80 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Kumaraswamy Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import Kumaraswamy +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class KumaraswamyBijectorTest(test.TestCase): + """Tests correctness of the Kumaraswamy bijector.""" + + def testBijector(self): + with self.test_session(): + a = 2. + b = 0.3 + bijector = Kumaraswamy( + concentration1=a, concentration0=b, + event_ndims=0, validate_args=True) + self.assertEqual("kumaraswamy", bijector.name) + x = np.array([[[0.1], [0.2], [0.3], [0.4], [0.5]]], dtype=np.float32) + # Kumaraswamy cdf. This is the same as inverse(x). + y = 1. - (1. - x ** a) ** b + self.assertAllClose(y, bijector.inverse(x).eval()) + self.assertAllClose(x, bijector.forward(y).eval()) + kumaraswamy_log_pdf = (np.log(a) + np.log(b) + (a - 1) * np.log(x) + + (b - 1) * np.log1p(-x ** a)) + + self.assertAllClose( + # We should lose a dimension from calculating the determinant of the + # jacobian. + kumaraswamy_log_pdf, + bijector.inverse_log_det_jacobian(x).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(x).eval(), + bijector.forward_log_det_jacobian(y).eval(), + rtol=1e-4, + atol=0.) + + def testScalarCongruency(self): + with self.test_session(): + assert_scalar_congruency( + Kumaraswamy(concentration1=0.5, concentration0=1.1), + lower_x=0., upper_x=1., n=int(10e3), rtol=0.02) + + def testBijectiveAndFinite(self): + with self.test_session(): + concentration1 = 1.2 + concentration0 = 2. + bijector = Kumaraswamy( + concentration1=concentration1, + concentration0=concentration0, validate_args=True) + # Omitting the endpoints 0 and 1, since idlj will be inifinity at these + # endpoints. + y = np.linspace(.01, 0.99, num=10).astype(np.float32) + x = 1 - (1 - y ** concentration1) ** concentration0 + assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py index ea3c86b5c0f42b64fc6e4e362cbcc162bccf74a2..2980e2bfe93b2e2aa01d38fc9fa4650a015efc06 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py @@ -130,10 +130,8 @@ class KumaraswamyTest(test.TestCase): dist.prob([.1, .3, .6]).eval() dist.prob([.2, .3, .5]).eval() # Either condition can trigger. - with self.assertRaisesOpError("sample must be positive"): + with self.assertRaisesOpError("sample must be non-negative"): dist.prob([-1., 0.1, 0.5]).eval() - with self.assertRaisesOpError("sample must be positive"): - dist.prob([0., 0.1, 0.5]).eval() with self.assertRaisesOpError("sample must be no larger than `1`"): dist.prob([.1, .2, 1.2]).eval() @@ -249,13 +247,13 @@ class KumaraswamyTest(test.TestCase): a = np.array([1., 2, 3]) b = np.array([2., 4, 1.2]) dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("Mode undefined for concentration1 <= 1."): dist.mode().eval() a = np.array([2., 2, 3]) b = np.array([1., 4, 1.2]) dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) - with self.assertRaisesOpError("Condition x < y.*"): + with self.assertRaisesOpError("Mode undefined for concentration0 <= 1."): dist.mode().eval() def testKumaraswamyModeEnableAllowNanStats(self): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py index d9c9008417cdb20b62390630cf887d3bd888a0d3..19a7472d91758a2dbd00c4d918853d7bae33685d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function import numpy as np +from scipy import special from scipy import stats from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib from tensorflow.python.framework import constant_op @@ -110,7 +111,7 @@ class PoissonTest(test.TestCase): batch_size = 6 lam = constant_op.constant([3.0] * batch_size) lam_v = 3.0 - x = [2.2, 3.1, 4., 5.5, 6., 7.] + x = [2., 3., 4., 5., 6., 7.] poisson = self._make_poisson(rate=lam) log_cdf = poisson.log_cdf(x) @@ -121,12 +122,31 @@ class PoissonTest(test.TestCase): self.assertEqual(cdf.get_shape(), (6,)) self.assertAllClose(cdf.eval(), stats.poisson.cdf(x, lam_v)) + def testPoissonCDFNonIntegerValues(self): + with self.test_session(): + batch_size = 6 + lam = constant_op.constant([3.0] * batch_size) + lam_v = 3.0 + x = np.array([2.2, 3.1, 4., 5.5, 6., 7.], dtype=np.float32) + + poisson = self._make_poisson(rate=lam) + cdf = poisson.cdf(x) + self.assertEqual(cdf.get_shape(), (6,)) + + # The Poisson CDF should be valid on these non-integer values, and + # equal to igammac(1 + x, rate). + self.assertAllClose(cdf.eval(), special.gammaincc(1. + x, lam_v)) + + with self.assertRaisesOpError("cannot contain fractional components"): + poisson_validate = self._make_poisson(rate=lam, validate_args=True) + poisson_validate.cdf(x).eval() + def testPoissonCdfMultidimensional(self): with self.test_session(): batch_size = 6 lam = constant_op.constant([[2.0, 4.0, 5.0]] * batch_size) lam_v = [2.0, 4.0, 5.0] - x = np.array([[2.2, 3.1, 4., 5.5, 6., 7.]], dtype=np.float32).T + x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=np.float32).T poisson = self._make_poisson(rate=lam) log_cdf = poisson.log_cdf(x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 93923c3f083c7f5136b55e9021cbd6323684b976..9437f56b1ebc76165edec224928baeb836277163 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -26,6 +26,7 @@ @@Identity @@Inline @@Invert +@@Kumaraswamy @@MaskedAutoregressiveFlow @@Permute @@PowerTransform @@ -59,6 +60,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.exp import * from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * +from tensorflow.contrib.distributions.python.ops.bijectors.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py new file mode 100644 index 0000000000000000000000000000000000000000..f5de052c9ed18b1ebf4c174aeea3a951b1ddcd9d --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/kumaraswamy.py @@ -0,0 +1,153 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kumaraswamy bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + +__all__ = [ + "Kumaraswamy", +] + + +class Kumaraswamy(bijector.Bijector): + """Compute `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a), X in [0, 1]`. + + This bijector maps inputs from `[0, 1]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the [Kumaraswamy distribution]( + https://en.wikipedia.org/wiki/Kumaraswamy_distribution): + + ```none + Y ~ Kumaraswamy(a, b) + pdf(y; a, b, 0 <= y <= 1) = a * b * y ** (a - 1) * (1 - y**a) ** (b - 1) + ``` + """ + + def __init__(self, + concentration1=None, + concentration0=None, + event_ndims=0, + validate_args=False, + name="kumaraswamy"): + """Instantiates the `Kumaraswamy` bijector. + + Args: + concentration1: Python `float` scalar indicating the transform power, + i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `a` is + `concentration1`. + concentration0: Python `float` scalar indicating the transform power, + i.e., `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)` where `b` is + `concentration0`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. Currently only zero is + supported. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: If `event_ndims` is not zero. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0,): + raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) + else: + if validate_args: + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_equal( + event_ndims, 0, message="event_ndims was not 0")], + event_ndims) + + with self._name_scope("init", values=[concentration1, concentration0]): + concentration1 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration1, name="concentration1"), + validate_args=validate_args) + concentration0 = self._maybe_assert_valid_concentration( + ops.convert_to_tensor(concentration0, name="concentration0"), + validate_args=validate_args) + + self._concentration1 = concentration1 + self._concentration0 = concentration0 + super(Kumaraswamy, self).__init__( + event_ndims=0, + validate_args=validate_args, + name=name) + + @property + def concentration1(self): + """The `a` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`.""" + return self._concentration1 + + @property + def concentration0(self): + """The `b` in: `Y = g(X) = (1 - (1 - X)**(1 / b))**(1 / a)`.""" + return self._concentration0 + + def _forward(self, x): + x = self._maybe_assert_valid(x) + return math_ops.exp( + math_ops.log1p(-math_ops.exp(math_ops.log1p(-x) / self.concentration0)) + / self.concentration1) + + def _inverse(self, y): + y = self._maybe_assert_valid(y) + return math_ops.exp(math_ops.log1p( + -(1 - y**self.concentration1)**self.concentration0)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.concentration1) + math_ops.log(self.concentration0) + + (self.concentration1 - 1) * math_ops.log(y) + + (self.concentration0 - 1) * math_ops.log1p(-y**self.concentration1), + axis=event_dims) + + def _maybe_assert_valid_concentration(self, concentration, validate_args): + """Checks the validity of a concentration parameter.""" + if not validate_args: + return concentration + return control_flow_ops.with_dependencies([ + check_ops.assert_positive( + concentration, + message="Concentration parameter must be positive."), + ], concentration) + + def _maybe_assert_valid(self, x): + if not self.validate_args: + return x + return control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + x, + message="sample must be non-negative"), + check_ops.assert_less_equal( + x, array_ops.ones([], self.concentration0.dtype), + message="sample must be no larger than `1`."), + ], x) diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py index 74d5d8773cf3e69a52554c87d656fea2835c8354..120b38db3cf72e8fce56a7e9293cdf25e75784e2 100644 --- a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -20,15 +20,17 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import special_math_ops -from tensorflow.python.ops.distributions import beta from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.distributions import uniform from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -60,7 +62,7 @@ def _harmonic_number(x): @tf_export("distributions.Kumaraswamy") -class Kumaraswamy(beta.Beta): +class Kumaraswamy(transformed_distribution.TransformedDistribution): """Kumaraswamy distribution. The Kumaraswamy distribution is defined over the `(0, 1)` interval using @@ -151,59 +153,32 @@ class Kumaraswamy(beta.Beta): more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. """ + concentration1 = ops.convert_to_tensor( + concentration1, name="concentration1") + concentration0 = ops.convert_to_tensor( + concentration0, name="concentration0") super(Kumaraswamy, self).__init__( - concentration1=concentration1, - concentration0=concentration0, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats, + distribution=uniform.Uniform( + low=array_ops.zeros([], dtype=concentration1.dtype), + high=array_ops.ones([], dtype=concentration1.dtype), + allow_nan_stats=allow_nan_stats), + bijector=bijectors.Kumaraswamy( + concentration1=concentration1, concentration0=concentration0, + validate_args=validate_args), + batch_shape=distribution_util.get_broadcast_shape( + concentration1, concentration0), name=name) self._reparameterization_type = distribution.FULLY_REPARAMETERIZED - def _sample_n(self, n, seed=None): - expanded_concentration1 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration1 - expanded_concentration0 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration0 - shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) - uniform_sample = random_ops.random_uniform( - shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed) - - kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**( - 1. / expanded_concentration1) - return kumaraswamy_sample - - @distribution_util.AppendDocstring(_kumaraswamy_sample_note) - def _log_cdf(self, x): - a = self.concentration1 - b = self.concentration0 - return math_ops.log1p(-(1 - x**a)**b) + @property + def concentration1(self): + """Concentration parameter associated with a `1` outcome.""" + return self.bijector.concentration1 - @distribution_util.AppendDocstring(_kumaraswamy_sample_note) - def _cdf(self, x): - a = self.concentration1 - b = self.concentration0 - return 1 - (1 - x**a)**b - - def _survival_function(self, x): - a = self.concentration1 - b = self.concentration0 - return (1 - x**a)**b - - def _log_survival_function(self, x): - a = self.concentration1 - b = self.concentration0 - return b * math_ops.log1p(-x**a) - - def _log_unnormalized_prob(self, x): - x = self._maybe_assert_valid_sample(x) - a = self.concentration1 - b = self.concentration0 - return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a) - - def _log_normalization(self): - a = self.concentration1 - b = self.concentration0 - return -(math_ops.log(a) + math_ops.log(b)) + @property + def concentration0(self): + """Concentration parameter associated with a `0` outcome.""" + return self.bijector.concentration0 def _entropy(self): a = self.concentration1 @@ -213,10 +188,11 @@ class Kumaraswamy(beta.Beta): def _moment(self, n): """Compute the n'th (uncentered) moment.""" + total_concentration = self.concentration1 + self.concentration0 expanded_concentration1 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration1 + total_concentration, dtype=self.dtype) * self.concentration1 expanded_concentration0 = array_ops.ones_like( - self.total_concentration, dtype=self.dtype) * self.concentration0 + total_concentration, dtype=self.dtype) * self.concentration0 beta_arg0 = 1 + n / expanded_concentration1 beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1) log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta( @@ -246,13 +222,14 @@ class Kumaraswamy(beta.Beta): name="nan") is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.) return array_ops.where(is_defined, mode, nan) + return control_flow_ops.with_dependencies([ check_ops.assert_less( - array_ops.ones([], dtype=self.dtype), + array_ops.ones([], dtype=self.concentration1.dtype), self.concentration1, message="Mode undefined for concentration1 <= 1."), check_ops.assert_less( - array_ops.ones([], dtype=self.dtype), + array_ops.ones([], dtype=self.concentration0.dtype), self.concentration0, message="Mode undefined for concentration0 <= 1.") ], mode) diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index e967dcc90d0712ffc346fb61ee67c44a6d9207cb..02e97c0a2fd004c4fa9382d5367af9f5b034a869 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -35,9 +35,15 @@ __all__ = [ _poisson_sample_note = """ -Note that the input value must be a non-negative floating point tensor with -dtype `dtype` and whose shape can be broadcast with `self.rate`. `x` is only -legal if it is non-negative and its components are equal to integer values. +The Poisson distribution is technically only defined for non-negative integer +values. When `validate_args=False`, non-integral inputs trigger an assertion. + +When `validate_args=False` calculations are otherwise unchanged despite +integral or non-integral inputs. + +When `validate_args=False`, evaluating the pmf at non-integral values, +corresponds to evaluations of an unnormalized distribution, that does not +correspond to evaluations of the cdf. """ @@ -150,10 +156,6 @@ class Poisson(distribution.Distribution): def _cdf(self, x): if self.validate_args: x = distribution_util.embed_check_nonnegative_integer_form(x) - else: - # Whether or not x is integer-form, the following is well-defined. - # However, scipy takes the floor, so we do too. - x = math_ops.floor(x) return math_ops.igammac(1. + x, self.rate) def _log_normalization(self): @@ -162,9 +164,6 @@ class Poisson(distribution.Distribution): def _log_unnormalized_prob(self, x): if self.validate_args: x = distribution_util.embed_check_nonnegative_integer_form(x) - else: - # For consistency with cdf, we take the floor. - x = math_ops.floor(x) return x * self.log_rate - math_ops.lgamma(1. + x) def _mean(self): diff --git a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto index 4f71aec96a2c3edee8a32b4e14584bd56ef3d439..024765acb28726fd102dfbf167f4e780072ce6e7 100644 --- a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto +++ b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto @@ -4,9 +4,9 @@ option cc_enable_arenas = true; package tensorflow.contrib.eager; -// Prototype for an addition to BundleHeaderProto which saves extra information -// about the objects which own variables, allowing for more robust checkpoint -// loading into modified programs. +// Prototype format which saves extra information about the objects which own +// variables, allowing for more robust checkpoint loading into modified +// programs. Currently stored in its own entry in a TensorBundle. message CheckpointableObjectGraph { message Object { @@ -18,37 +18,35 @@ message CheckpointableObjectGraph { string local_name = 2; } - message VariableReference { - // A name for the variable which is unique within the object which owns - // it. Does not include a name_scope or variable_scope prefix. - string local_name = 1; - // The full name of the variable. Used to allow name-based loading of - // checkpoints which were saved using an object-based API. + message SerializedTensor { + // A name for the Tensor. Simple variables have only one + // `SerializedTensor` named "VARIABLE_VALUE" by convention. This value may + // be restored on object creation as an optimization. + string name = 1; + // The full name of the variable/tensor, if applicable. Used to allow + // name-based loading of checkpoints which were saved using an + // object-based API. Should match the checkpoint key which would have been + // assigned by tf.train.Saver. string full_name = 2; - // The generated name of the variable in the checkpoint. + // The generated name of the Tensor in the checkpoint. string checkpoint_key = 3; } message SlotVariableReference { - // An index into `CheckpointableObjectGraph.nodes`, indicating the object - // which created the variable that this variable is slotting for. + // An index into `CheckpointableObjectGraph.nodes`, indicating the + // variable object this slot was created for. int32 original_variable_node_id = 1; - // The local name of the variable being slotted for within the object that - // owns it. - string original_variable_local_name = 2; // The name of the slot (e.g. "m"/"v"). - string slot_name = 3; - // The full name of the slot variable. Used to allow name-based loading of - // checkpoints which were saved using an object-based API. - string full_name = 4; - // The generated name of the variable in the checkpoint. - string checkpoint_key = 5; + string slot_name = 2; + // An index into `CheckpointableObjectGraph.nodes`, indicating the + // `Object` with the value of the slot variable. + int32 slot_variable_node_id = 3; } // Objects which this object depends on. repeated ObjectReference children = 1; - // Non-slot variables owned by this object. - repeated VariableReference variables = 2; + // Serialized data specific to this object. + repeated SerializedTensor attributes = 2; // Slot variables owned by this object. repeated SlotVariableReference slot_variables = 3; } diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index cfb38a1d26c41a3923da7c989244a3d53b6a496b..a26ec8513f4b7b9c278edddc95e6acd2523194f2 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -69,6 +69,7 @@ cuda_py_test( srcs = ["datasets_test.py"], additional_deps = [ ":datasets", + "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", @@ -220,18 +221,19 @@ py_test( ) py_library( - name = "checkpointable", - srcs = ["checkpointable.py"], + name = "checkpointable_utils", + srcs = ["checkpointable_utils.py"], srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ "//tensorflow/contrib/eager/proto:checkpointable_object_graph_proto_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", "//tensorflow/python:io_ops", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:state_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:training", "//tensorflow/python:variable_scope", @@ -240,11 +242,11 @@ py_library( ) py_test( - name = "checkpointable_test", - srcs = ["checkpointable_test.py"], + name = "checkpointable_utils_test", + srcs = ["checkpointable_utils_test.py"], srcs_version = "PY2AND3", deps = [ - ":checkpointable", + ":checkpointable_utils", ":network", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py deleted file mode 100644 index 896b38a7348e1fdd5a13b197e3ee34f5c4c5a22c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/checkpointable.py +++ /dev/null @@ -1,773 +0,0 @@ -"""An object-local variable management scheme.""" -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections -import re -import weakref - -from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 -from tensorflow.python.eager import context -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.training import optimizer as optimizer_lib -from tensorflow.python.training import saver as saver_lib -from tensorflow.python.training import slot_creator -from tensorflow.python.training import training - -_CheckpointableReference = collections.namedtuple( - "_CheckpointableReference", - [ - # The local name if explicitly specified, else None. - "name", - # The Checkpointable object being referenced. - "ref" - ]) - -# Validation regular expression for the local names of Checkpointable -# objects. In particular, disallows "/" in names, and reserves dash-prefixed -# names (which are not valid Python identifiers, so we're not restricting the -# __setattr__ syntax that way). -_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9_.][A-Za-z0-9_.-]*$") - -# Keyword for identifying that the next bit of a checkpoint variable name is a -# slot name. May not be the local name of a checkpointable. Checkpoint names for -# slot variables look like: -# -# /<_OPTIMIZER_SLOTS_NAME>// -# -# Where is a full path from the checkpoint root to the -# variable being slotted for. -_OPTIMIZER_SLOTS_NAME = "-OPTIMIZER_SLOT" - - -def _assign_existing_variable(variable_to_restore, value_pointer): - """Set a variable from a _ValuePointer object.""" - base_type = variable_to_restore.dtype.base_dtype - with ops.colocate_with(variable_to_restore): - # TODO(allenl): Handle partitioned variables - value_to_restore, = io_ops.restore_v2( - prefix=value_pointer.save_path, - tensor_names=[value_pointer.checkpoint_key], - shape_and_slices=[""], - dtypes=[base_type], - name="checkpoint_initializer") - initializer_op = state_ops.assign(variable_to_restore, value_to_restore) - variable_to_restore._initializer_op = initializer_op # pylint:disable=protected-access - if value_pointer.session is not None: - value_pointer.session.run(initializer_op) - - -def _default_getter(name, shape, dtype, initializer=None, - partition_info=None, **kwargs): - """A pared-down version of get_variable which does not reuse variables.""" - dtype = dtypes.as_dtype(dtype) - shape_object = tensor_shape.as_shape(shape) - with ops.init_scope(): - if initializer is None: - initializer, initializing_from_value = ( - variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access - name=name, shape=shape_object, dtype=dtype)) - else: - initializing_from_value = not callable(initializer) - # Same logic as get_variable - if initializing_from_value: - if shape is not None: - raise ValueError("If initializer is a constant, do not specify shape.") - initial_value = initializer - variable_dtype = None - else: - # Instantiate initializer if provided initializer is a type object. - if isinstance(initializer, type(init_ops.Initializer)): - initializer = initializer(dtype=dtype) - def initial_value(): - return initializer( - shape_object.as_list(), dtype=dtype, partition_info=partition_info) - variable_dtype = dtype.base_dtype - return resource_variable_ops.ResourceVariable( - initial_value=initial_value, - name=name, - dtype=variable_dtype, - **kwargs - ) - - -class Checkpointable(object): - """Manages variables and dependencies on other objects. - - To make reliable checkpoints, all `Checkpointable`s on which this object - depends must be registered in the constructor using `track_checkpointable` in - a deterministic order, and if possible they should be named. Variables may be - created using `add_variable` outside of the constructor and in any order, but - only these variables will be saved. - """ - - def __init__(self): - # A list of _CheckpointableReference objects. - self._checkpoint_dependencies = [] - # Maps names -> Checkpointable objects for named dependencies - self._dependency_names = {} - # Set of all tracked Checkpointables - self._already_tracked = set() - self._owned_variables = {} # local name -> variable object - self._deferred_restorations = {} # local name -> _VariableRestoration - # object - - def __setattr__(self, name, value): - """Support self.foo = checkpointable syntax. - - `self.foo = checkpointable` is equivalent to - `self.foo = self.track_checkpointable(checkpointable, name='foo')`. - - No new tracking if `value` is not a `Checkpointable`, or if `value` is - already being tracked (either because of an explicit `track_checkpointable` - or a previous `__setattr__`). - - Args: - name: The name of the property being set. - value: The new value for the property. - """ - # Give child classes (e.g. Network) priority, then track only if the object - # hasn't been added to _already_tracked. - super(Checkpointable, self).__setattr__(name, value) - if (isinstance(value, Checkpointable) - and value not in self._already_tracked): - self.track_checkpointable(value, name=name) - - def add_variable(self, name, shape=None, dtype=dtypes.float32, - initializer=None, **kwargs): - """Create a new variable object to be saved with this `Checkpointable`. - - If the user has requested that this object or another `Checkpointable` which - depends on this object be restored from a checkpoint (deferred loading - before variable object creation), `initializer` may be ignored and the value - from the checkpoint used instead. - - Args: - name: A name for the variable. Must be unique within this object. - shape: The shape of the variable. - dtype: The data type of the variable. - initializer: The initializer to use. Ignored if deferred loading has been - requested. - **kwargs: Passed to the ResourceVariable constructor. - - Returns: - The new variable object. - - Raises: - ValueError: If the variable name is not unique. - RuntimeError: If __init__ has not been called. - """ - if not hasattr(self, "_owned_variables"): - raise RuntimeError("Need to call Checkpointable.__init__ before adding " - "variables.") - if name in self._owned_variables: - raise ValueError( - ("A variable named '%s' already exists in this Checkpointable, but " - "Checkpointable.add_variable called to create another with " - "that name. Variable names must be unique within a Checkpointable " - "object.") % (name,)) - if "getter" in kwargs: - # Allow the getter to be overridden, typically because there is a need for - # compatibility with some other variable creation mechanism. This should - # be relatively uncommon in user code. - getter = kwargs.pop("getter") - else: - getter = _default_getter - deferred_restoration = self._deferred_restorations.pop(name, None) - if deferred_restoration is not None: - dtype = deferred_restoration.value_pointer.dtype - base_type = dtype.base_dtype - # TODO(allenl): Handle partitioned variables here too - with ops.init_scope(): - initializer, = io_ops.restore_v2( - prefix=deferred_restoration.value_pointer.save_path, - tensor_names=[deferred_restoration.value_pointer.checkpoint_key], - shape_and_slices=[""], - dtypes=[base_type], - name="checkpoint_initializer") - # We need to un-set the shape so get_variable doesn't complain, but we - # also need to set the static shape information on the initializer if - # possible so we don't get a variable with an unknown shape. - initializer.set_shape(shape) - # Un-set shape since we're using a constant initializer - shape = None - - new_variable = getter( - name=name, shape=shape, dtype=dtype, initializer=initializer, **kwargs) - if deferred_restoration is not None: - if deferred_restoration.value_pointer.session is not None: - deferred_restoration.value_pointer.session.run(new_variable.initializer) - for slot_restoration in deferred_restoration.slot_restorations: - strong_ref = slot_restoration.optimizer_ref() - if strong_ref is None: - # If the optimizer object has been garbage collected, there's no need - # to create the slot variable. - continue - strong_ref._process_slot_restoration( # pylint: disable=protected-access - slot_restoration, new_variable) - self._owned_variables[name] = new_variable - return new_variable - - def track_checkpointable(self, checkpointable, name): - """Declare a dependency on another `Checkpointable` object. - - Indicates that checkpoints for this object should include variables from - `checkpointable`. - - Variables in a checkpoint are mapped to `Checkpointable`s based on names. To - avoid breaking existing checkpoints when modifying a class, neither variable - names nor dependency names (the names passed to `track_checkpointable`) may - change. - - Args: - checkpointable: A `Checkpointable` which this object depends on. - name: A local name for `checkpointable`, used for loading checkpoints into - the correct objects. Python 2 identifiers are valid names, with the - addition of leading numerals, periods anywhere, and non-leading dashes. - Specifically names must match the regular expression - `^[A-Za-z0-9_.][A-Za-z0-9_.-]*$`. - - Returns: - `checkpointable`, for convenience when declaring a dependency and - assigning to a member variable in one statement. - - Raises: - RuntimeError: If __init__ was not called. - TypeError: If `checkpointable` does not inherit from `Checkpointable`. - ValueError: For invalid names. - """ - if not hasattr(self, "_checkpoint_dependencies"): - raise RuntimeError("Need to call Checkpointable.__init__ before calling " - "Checkpointable.track_checkpointable().") - if not isinstance(checkpointable, Checkpointable): - raise TypeError( - ("Checkpointable.track_checkpointable() passed type %s, not a " - "Checkpointable.") % (type(checkpointable),)) - if not _VALID_LOCAL_NAME.match(name): - raise ValueError( - ("Checkpointable names must match the regular expression '%s', but " - "got an invalid name '%s' instead.") % (_VALID_LOCAL_NAME.pattern, - name)) - if (name in self._dependency_names - and self._dependency_names[name] is not checkpointable): - raise ValueError( - ("Called Checkpointable.track_checkpointable() with name='%s', but " - "a Checkpointable with this name is already declared as a " - "dependency. Names must be unique.") % (name,)) - self._dependency_names[name] = checkpointable - self._checkpoint_dependencies.append( - _CheckpointableReference(name=name, ref=checkpointable)) - self._already_tracked.add(checkpointable) - return checkpointable - - def _process_restoration(self, restoration): - """Restore a variable and its slot variables (may be deferred).""" - variable_to_restore = self._owned_variables.get(restoration.name, None) - if variable_to_restore is not None: - # This variable already exists, so just do an assignment for this and any - # slot variables which depend on it. - _assign_existing_variable( - variable_to_restore, value_pointer=restoration.value_pointer) - for slot_restoration in restoration.slot_restorations: - strong_ref = slot_restoration.optimizer_ref() - if strong_ref is None: - continue - strong_ref._process_slot_restoration( # pylint: disable=protected-access - slot_restoration, variable_to_restore) - else: - # Save this restoration for later. This intentionally overwrites any - # previous deferred restorations, since that gives the same semantics as - # direct assignment. - self._deferred_restorations[restoration.name] = restoration - - def _process_slot_restoration(self, slot_restoration, variable): - """Restore a slot variable's value (creating it if necessary).""" - # TODO(allenl): Move this to Optimizer - assert isinstance(self, optimizer_lib.Optimizer) - named_slots = self._slot_dict(slot_restoration.slot_name) - variable_key = optimizer_lib._var_key(variable) # pylint: disable=protected-access - existing_slot_variable = named_slots.get(variable_key, None) - if existing_slot_variable is None: - base_dtype = slot_restoration.value_pointer.dtype.base_dtype - initializer, = io_ops.restore_v2( - prefix=slot_restoration.value_pointer.save_path, - tensor_names=[slot_restoration.value_pointer.checkpoint_key], - shape_and_slices=[""], - dtypes=[base_dtype], - name="checkpoint_initializer") - new_slot_variable = slot_creator.create_slot(variable, initializer, - slot_restoration.slot_name) - if slot_restoration.value_pointer.session is not None: - slot_restoration.value_pointer.session.run( - new_slot_variable.initializer) - named_slots[variable_key] = new_slot_variable - else: - _assign_existing_variable( - existing_slot_variable, value_pointer=slot_restoration.value_pointer) - - @property - def checkpoint_dependencies(self): - """Other `Checkpointable` objects on which this object depends.""" - return self._checkpoint_dependencies - - -def _breadth_first_checkpointable_traversal(root_checkpointable): - """Find shortest paths to all variables owned by dependencies of root.""" - bfs_sorted = [] - root_checkpointable_reference = _CheckpointableReference( - name=None, ref=root_checkpointable) - to_visit = collections.deque([root_checkpointable_reference]) - path_to_root = {root_checkpointable_reference: ()} - while to_visit: - current_checkpointable = to_visit.popleft() - bfs_sorted.append(current_checkpointable) - for child_checkpointable in ( - current_checkpointable.ref.checkpoint_dependencies): - if child_checkpointable not in path_to_root: - path_to_root[child_checkpointable] = ( - path_to_root[current_checkpointable] + (child_checkpointable,)) - to_visit.append(child_checkpointable) - return bfs_sorted, path_to_root - - -def _object_prefix_from_path(path_to_root): - return "/".join( - (checkpointable.name for checkpointable in path_to_root)) - - -def _escape_variable_name(variable_name): - # We need to support slashes in variable names for compatibility, since this - # naming scheme is being patched in to things like Layer.add_variable where - # slashes were previously accepted. We also want to use slashes to indicate - # edges traversed to reach the variable, so we escape forward slashes in - # variable names. - return variable_name.replace("_S_", "_S_.").replace(r"/", r"_S__") - - -def _variable_naming_for_object(path_to_root): - """Make a function for naming variables in an object.""" - # Name non-slot variables: - # - # / - # - # is not necessarily unique, but this is fine since we also - # save the graph of `Checkpointable`s with the checkpoint. Even if this path - # no longer exists because of a change in the Python program, we can look up - # the `Checkpointable` which owns the variable in the checkpoint's graph and - # use another path if one still exists. - - object_prefix = _object_prefix_from_path(path_to_root) - if object_prefix: - object_prefix += "/" - - def _name_single_variable(local_name): - """Names a variable within an object.""" - return object_prefix + _escape_variable_name(local_name) - - return _name_single_variable - - -def _slot_variable_naming_for_optimizer(optimizer, path_to_root): - """Make a function for naming slot variables in an optimizer.""" - # Name slot variables: - # - # /<_OPTIMIZER_SLOTS_NAME>// - # - # where is exactly the checkpoint name used for the original - # variable, including the path from the checkpoint root and the local name in - # the object which owns it. Note that we only save slot variables if the - # variable it's slotting for is also being saved. - - optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, - _object_prefix_from_path(path_to_root)) - - def _name_slot_variable(variable_path, slot_name): - """With an optimizer specified, name a slot variable.""" - - if not _VALID_LOCAL_NAME.match(slot_name): - # Slot variable names include the name of the slot. We need to - # validate that part of the name to be sure that the checkpoint name - # is a valid name scope name. - raise ValueError( - ("Could not save slot variables for optimizer %s, because its " - "slot name has invalid characters (got '%s', was expecting it " - "to match the regular expression '%s').") % - (optimizer, slot_name, _VALID_LOCAL_NAME.pattern)) - - return variable_path + optimizer_identifier + slot_name - - return _name_slot_variable - - -def _serialize_non_slot_variables(checkpointable_objects, path_to_root, - object_graph_proto): - """Name non-slot variables and add them to `object_graph_proto`.""" - named_variables = {} - non_slot_variables = [] - checkpoint_node_ids = {} - - for checkpoint_id, checkpointable in enumerate(checkpointable_objects): - checkpoint_node_ids[checkpointable] = checkpoint_id - - for checkpoint_id, checkpointable in enumerate(checkpointable_objects): - naming_scheme = _variable_naming_for_object(path_to_root[checkpointable]) - object_proto = object_graph_proto.nodes.add() - for (local_name, owned_variable) in sorted( - checkpointable.ref._owned_variables.items(), # pylint: disable=protected-access - key=lambda x: x[0]): - variable_name = naming_scheme(local_name) - named_variables[variable_name] = owned_variable - non_slot_variables.append(( - variable_name, # The variable's full checkpoint name - owned_variable, # The variable object - local_name, # The variable's local name - checkpoint_id)) # The checkpoint ID of the node which owns this - # variable. - variable_proto = object_proto.variables.add() - variable_proto.local_name = local_name - variable_proto.checkpoint_key = variable_name - # Figure out the name-based Saver's name for this variable. - saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( - [owned_variable], convert_variable_to_tensor=False) - variable_full_name, = saver_dict.keys() - variable_proto.full_name = variable_full_name - - for child in checkpointable.ref.checkpoint_dependencies: - child_proto = object_proto.children.add() - child_proto.node_id = checkpoint_node_ids[child] - child_proto.local_name = child.name - return named_variables, non_slot_variables - - -def _serialize_slot_variables(checkpointable_objects, path_to_root, - non_slot_variables, object_graph_proto): - """Name slot variables and add them to `object_graph_proto`.""" - named_slot_variables = {} - for optimizer_checkpoint_id, checkpointable_ref in enumerate( - checkpointable_objects): - if isinstance(checkpointable_ref.ref, optimizer_lib.Optimizer): - optimizer_object_proto = object_graph_proto.nodes[optimizer_checkpoint_id] - naming_scheme = _slot_variable_naming_for_optimizer( - optimizer=checkpointable_ref.ref, - path_to_root=path_to_root[checkpointable_ref]) - slot_names = checkpointable_ref.ref.get_slot_names() - for (variable_path, original_variable, original_variable_local_name, - original_node_checkpoint_id) in non_slot_variables: - for slot_name in slot_names: - slot_variable = checkpointable_ref.ref.get_slot( - original_variable, slot_name) - if slot_variable is not None: - checkpoint_name = naming_scheme( - variable_path=variable_path, slot_name=slot_name) - named_slot_variables[checkpoint_name] = slot_variable - slot_variable_proto = optimizer_object_proto.slot_variables.add() - slot_variable_proto.slot_name = slot_name - slot_variable_proto.checkpoint_key = checkpoint_name - # Figure out the name-based Saver's name for this variable. - saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( - [slot_variable], convert_variable_to_tensor=False) - slot_variable_full_name, = saver_dict.keys() - slot_variable_proto.full_name = slot_variable_full_name - slot_variable_proto.original_variable_local_name = ( - original_variable_local_name) - slot_variable_proto.original_variable_node_id = ( - original_node_checkpoint_id) - return named_slot_variables - - -# TODO(allenl): Convenience utility for saving multiple objects (i.e. construct -# a root Checkpointable if passed a list of Checkpointables). -def _serialize_object_graph(root_checkpointable): - """Determine checkpoint keys for variables and build a serialized graph. - - Non-slot variables are keyed based on a shortest path from the root saveable - to the object which owns the variable (i.e. the one which called - `Checkpointable.add_variable` to create it). - - Slot variables are keyed based on a shortest path to the variable being - slotted for, a shortest path to their optimizer, and the slot name. - - Args: - root_checkpointable: A `Checkpointable` object whose variables (including - the variables of dependencies, recursively) should be saved. - - Returns: - A tuple of (named_variables, object_graph_proto): - named_variables: A dictionary mapping names to variable objects. - object_graph_proto: A CheckpointableObjectGraph protocol buffer containing - the serialized object graph and variable references. - - Raises: - ValueError: If there are invalid characters in an optimizer's slot names. - """ - checkpointable_objects, path_to_root = ( - _breadth_first_checkpointable_traversal(root_checkpointable)) - object_graph_proto = ( - checkpointable_object_graph_pb2.CheckpointableObjectGraph()) - - # Gather non-slot variables. - named_variables, non_slot_variables = _serialize_non_slot_variables( - checkpointable_objects, path_to_root, object_graph_proto) - - # Gather slot variables which are associated with variables gathered above. - named_slot_variables = _serialize_slot_variables( - checkpointable_objects, path_to_root, non_slot_variables, - object_graph_proto) - - named_variables.update(named_slot_variables) - return named_variables, object_graph_proto - - -def _set_reference(reference_proto_table, key, checkpointable, parent, - object_id_map): - """Record a checkpoint<->object correspondence, with error checking. - - Args: - reference_proto_table: Map from names or numbers to `ObjectReference` protos - within the parent object. - key: Either a numeric or string identifier for the reference. - checkpointable: The object to record a correspondence for. - parent: The parent Python object, for creating a useful error message. - object_id_map: The map from `node_id` to Python object in which to record - the reference. - Returns: - The `node_id` of the Object proto corresponding to the specified Python - object. - Raises: - AssertionError: If another object is already bound to the `Object` proto. - """ - reference_proto = reference_proto_table[key] - set_reference = object_id_map.setdefault(reference_proto.node_id, - checkpointable) - if set_reference is not checkpointable: - raise AssertionError( - ("Unable to load the checkpoint into this object graph. Either " - "the Checkpointable object references in the Python program " - "have changed in an incompatible way, or the checkpoint was " - "generated in an incompatible program.\n\nTwo checkpoint " - "references (one being '%s' in %s) resolved to different " - "objects (%s and %s).") % (key, parent, set_reference, - checkpointable)) - return reference_proto.node_id - - -def _checkpoint_object_id_map(root_checkpointable, object_graph_proto): - """Match a checkpointed object graph to a Python object graph. - - Args: - root_checkpointable: A Checkpointable object. - object_graph_proto: A CheckpointableObjectGraph protocol buffer representing - a serialized object graph. - Returns: - A dictionary mapping from checkpoint node ids (indices into - `object_graph_proto.nodes`) to `Checkpointable` objects which are - dependencies of `root_checkpointable`. - """ - node_list = object_graph_proto.nodes - # Queue of (checkpointable object, node id) - to_visit = collections.deque([(root_checkpointable, 0)]) - object_id_map = {0: root_checkpointable} - seen = set() - while to_visit: - checkpointable, node_id = to_visit.popleft() - object_proto = node_list[node_id] - named_children = {} - for child_reference in object_proto.children: - if child_reference.local_name: - named_children[child_reference.local_name] = child_reference - else: - raise AssertionError( - ("The checkpointed object graph contains a reference without " - "a name (corrupted?). The reference was from the node %s.") - % (object_proto,)) - - for checkpointable_reference in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access - child_node_id = _set_reference( - reference_proto_table=named_children, - key=checkpointable_reference.name, - checkpointable=checkpointable_reference.ref, - parent=checkpointable, - object_id_map=object_id_map) - if child_node_id not in seen: - seen.add(child_node_id) - to_visit.append((checkpointable_reference.ref, child_node_id)) - - return object_id_map - - -_ValuePointer = collections.namedtuple( - "_ValuePointer", - [ - # Information needed to look up the value to restore. - "save_path", - "checkpoint_key", - "dtype", - # The session to use when restoring (None when executing eagerly) - "session", - ]) - -_SlotVariableRestoration = collections.namedtuple( - "_SlotVariableRestoration", - [ - # A weak reference to the Optimizer object - "optimizer_ref", - # The slot name - "slot_name", - # The _ValuePointer to use when restoring - "value_pointer", - ]) - -_VariableRestoration = collections.namedtuple( - "_VariableRestoration", - [ - # The variable's (local) name. - "name", - # _SlotVariableRestoration objects indicating slot variables which - # should be created once this variable has been restored. - "slot_restorations", - # The _ValuePointer to use when restoring - "value_pointer", - ]) - - -def _gather_restorations(object_graph_proto, save_path, object_id_map, - dtype_map, session): - """Iterate over variables to restore, matching with Checkpointable objects.""" - variable_to_slot_restorations = {} - for node_id, node in enumerate(object_graph_proto.nodes): - for slot_variable in node.slot_variables: - original_variable_key = (slot_variable.original_variable_node_id, - slot_variable.original_variable_local_name) - variable_to_slot_restorations.setdefault( - original_variable_key, []).append( - _SlotVariableRestoration( - optimizer_ref=weakref.ref(object_id_map[node_id]), - slot_name=slot_variable.slot_name, - value_pointer=_ValuePointer( - save_path=save_path, - checkpoint_key=slot_variable.checkpoint_key, - dtype=dtype_map[slot_variable.checkpoint_key], - session=session))) - - for node_id, node in enumerate(object_graph_proto.nodes): - for variable in node.variables: - slots_key = (node_id, variable.local_name) - variable_restore = _VariableRestoration( - name=variable.local_name, - slot_restorations=variable_to_slot_restorations.get(slots_key, []), - value_pointer=_ValuePointer( - save_path=save_path, - checkpoint_key=variable.checkpoint_key, - dtype=dtype_map[variable.checkpoint_key], - session=session)) - yield variable_restore, object_id_map[node_id] - - -def save(file_prefix, root_checkpointable, global_step=None, session=None): - """Save a training checkpoint. - - Args: - file_prefix: A prefix to use for the checkpoint filenames - (/path/to/directory/and_a_prefix). Names are generated based on this - prefix and the global step, if provided. - root_checkpointable: A Checkpointable object to save. The checkpoint - includes variables created by this object and any Checkpointable objects - it depends on. - global_step: An integer variable or Tensor, used to number - checkpoints. Typically this value is saved along with other variables in - training checkpoints, which will happen automatically if it was created by - `root_checkpointable` or one of its dependencies (via - `Checkpointable.add_variable`). - session: The session to evaluate variables in. Ignored when executing - eagerly. If not provided when graph building, the default session is used. - - Returns: - The full path to the checkpoint. - - Currently also returns the serialized object graph proto, but that will go - away once it's saved with the checkpoint. - """ - named_variables, serialized_graph = _serialize_object_graph( - root_checkpointable) - if context.in_graph_mode(): - if session is None: - session = ops.get_default_session() - else: - session = None - with ops.device("/device:CPU:0"): - save_path = saver_lib.Saver(var_list=named_variables).save( - sess=session, - save_path=file_prefix, - write_meta_graph=False, - global_step=global_step) - # TODO(allenl): Save the graph with the checkpoint, then returning it and - # taking it as an argument to restore won't be necessary. - return serialized_graph, save_path - - -# NOTE: Will be restore(file_prefix, root_checkpointable) once the object graph -# is saved with the checkpoint. -def restore(save_path, root_checkpointable, object_graph_proto, session=None): - """Restore a training checkpoint. - - Restores the values of variables created with `Checkpointable.add_variable` in - the dependency graph of `root_checkpointable`. Either assigns values - immediately (if variables to restore have been created already), or defers - restoration until the variables are created. - - When building a graph, restorations are executed in the default session if - `session` is `None`. Variable initializers read checkpointed values. - - Args: - save_path: The path to the checkpoint, as returned by `save` or - `tf.train.latest_checkpoint`. If None (as when there is no latest - checkpoint for `tf.train.latest_checkpoint` to return), does nothing. - root_checkpointable: The root of the object graph to restore. Variables to - restore need not have been created yet, but all dependencies on other - Checkpointable objects should already be declared. Objects in the - dependency graph are matched to objects in the checkpointed graph, and - matching objects have their variables restored (or the checkpointed values - saved for eventual restoration when the variable is created). - object_graph_proto: (Temporary) the checkpointed object graph. This will - eventually be saved with the checkpoint, and will not be part of the final - API. - session: The session to evaluate assignment ops in. Ignored when executing - eagerly. If not provided when graph building, the default session is used. - """ - if save_path is None: - return - object_id_map = _checkpoint_object_id_map(root_checkpointable, - object_graph_proto) - reader = training.NewCheckpointReader(save_path) - dtype_map = reader.get_variable_to_dtype_map() - if context.in_graph_mode(): - if session is None: - session = ops.get_default_session() - else: - session = None - for restoration, checkpointable in _gather_restorations( - object_graph_proto, save_path, object_id_map, dtype_map, session=session): - checkpointable._process_restoration(restoration) # pylint: disable=protected-access - diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py deleted file mode 100644 index f7bc155decbb574ddd4b53190da3c3b3ee9b6a4e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/checkpointable_test.py +++ /dev/null @@ -1,497 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools -import os - -import six - -from tensorflow.contrib.eager.python import checkpointable -from tensorflow.contrib.eager.python import network as network_lib -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.layers import base -from tensorflow.python.layers import core -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.training import adam -from tensorflow.python.training import saver as core_saver -from tensorflow.python.training import training_util - - -class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): - - def __init__(self, *args, **kwargs): - checkpointable.Checkpointable.__init__(self) - core.Dense.__init__(self, *args, **kwargs) - - def add_variable(self, name, shape, **kwargs): - # Calls both Checkpointable.add_variable and Layer.add_variable. Eventually - # Layer.add_variable should inherit from Checkpointable and simply call - # super and then do post-processing. - return checkpointable.Checkpointable.add_variable( - self, - name=name, - shape=shape, - getter=functools.partial(core.Dense.add_variable, self), - **kwargs) - - -# pylint: disable=not-callable -class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): - - def __init__(self): - network_lib.Network.__init__(self) - checkpointable.Checkpointable.__init__(self) - - def __setattr__(self, name, value): - if isinstance(value, base.Layer) and value not in self._already_tracked: - self.track_layer(value, name=name) - # Checkpointable is next in the method resolution order, so this will catch - # Checkpointable objects which aren't Layers. - super(CheckpointableNetwork, self).__setattr__(name, value) - - def track_layer(self, layer, name): - self.track_checkpointable(layer, name=name) - return super(CheckpointableNetwork, self).track_layer(layer) - - -class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): - - def __init__(self, *args, **kwargs): - checkpointable.Checkpointable.__init__(self) - adam.AdamOptimizer.__init__(self, *args, **kwargs) - - # NOTE: Copied from Optimizer with modifications to use add_variable - # for non-slot variables. These contortions are necessary to maintain - # checkpoint compatibility with variable.name based saving. - # TODO(allenl): Make this cleaner. - def _create_non_slot_variable(self, initial_value, name, colocate_with): - """Add an extra variable, not associated with a slot.""" - if context.in_graph_mode(): - graph = colocate_with.graph - else: - graph = None - - key = (name, graph) - v = self._non_slot_dict.get(key, None) - if v is None: - with ops.colocate_with(colocate_with): - def _variable_getter(name, shape, dtype, initializer): - del shape, dtype # not used, but there for compatibility - return variable_scope.variable( - name=name, initial_value=initializer, trainable=False) - - initial_value = ops.convert_to_tensor(initial_value) - v = self.add_variable( - name=name, - shape=initial_value.get_shape(), - initializer=initial_value, - getter=_variable_getter) - - self._non_slot_dict[key] = v - - return v - - -class NonLayerCheckpointable(checkpointable.Checkpointable): - - def __init__(self): - super(NonLayerCheckpointable, self).__init__() - self.a_variable = self.add_variable(name="a_variable", shape=[]) - - -class MyNetwork(CheckpointableNetwork): - """A concrete Network for testing.""" - - def __init__(self): - super(MyNetwork, self).__init__() - self._named_dense = CheckpointableDenseLayer(1, use_bias=True) - self._via_track_layer = self.track_layer( - CheckpointableDenseLayer(1, use_bias=False), name="via_track_layer") - # We can still track Checkpointables which aren't Layers. - self._non_layer = NonLayerCheckpointable() - - def call(self, values): - return self._via_track_layer(self._named_dense(values)) - - -class Root(checkpointable.Checkpointable): - """A stand-in for a Trainer class.""" - - def __init__(self, optimizer, network): - super(Root, self).__init__() - self._optimizer = optimizer - self._network = self.track_checkpointable(network, "network") - self._global_step = None - - @property - def global_step(self): - if self._global_step is None: - # Get the default create_global_step utility to actually call - # self.add_variable, by setting a custom creator. - def _owned_variable_as_creator( - next_creator, initial_value, **kwargs): - def _creator_as_getter(initializer, **kwargs): - return next_creator(initial_value=initializer, **kwargs) - return self.add_variable( - getter=_creator_as_getter, initializer=initial_value, shape=[], - **kwargs) - - with variable_scope.variable_creator_scope( - _owned_variable_as_creator): - self._global_step = training_util.create_global_step() - return self._global_step - - -class InterfaceTests(test.TestCase): - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testAddVariable(self): - obj = NonLayerCheckpointable() - with self.assertRaisesRegexp(ValueError, "do not specify shape"): - obj.add_variable( - name="shape_specified_twice", shape=[], initializer=1) - constant_initializer = obj.add_variable( - name="constant_initializer", initializer=1) - with variable_scope.variable_scope("some_variable_scope"): - ones_initializer = obj.add_variable( - name="ones_initializer", - shape=[2], - initializer=init_ops.ones_initializer(dtype=dtypes.float32)) - bare_initializer = obj.add_variable( - name="bare_initializer", - shape=[2, 2], - dtype=dtypes.float64, - initializer=init_ops.zeros_initializer) - - # Even in graph mode, there are no naming conflicts between objects, only - # naming conflicts within an object. - other_duplicate = resource_variable_ops.ResourceVariable( - name="duplicate", initial_value=1.) - duplicate = obj.add_variable(name="duplicate", shape=[]) - with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"): - obj.add_variable(name="duplicate", shape=[]) - - if context.in_graph_mode(): - self.evaluate(variables.global_variables_initializer()) - self.assertEqual("constant_initializer:0", constant_initializer.name) - self.assertEqual(1, self.evaluate(constant_initializer)) - self.assertEqual("some_variable_scope/ones_initializer:0", - ones_initializer.name) - self.assertAllEqual([1, 1], self.evaluate(ones_initializer)) - self.assertAllEqual([[0., 0.], - [0., 0.]], self.evaluate(bare_initializer)) - self.assertEqual("a_variable:0", obj.a_variable.name) - self.assertEqual("duplicate:0", other_duplicate.name) - if context.in_graph_mode(): - # The .name attribute may be globally influenced, but the checkpoint name - # won't be (tested below). - self.assertEqual("duplicate_1:0", duplicate.name) - else: - # When executing eagerly, there's no uniquification of variable names. The - # checkpoint name will be the same. - self.assertEqual("duplicate:0", duplicate.name) - named_variables, _ = checkpointable._serialize_object_graph(obj) - expected_checkpoint_names = ( - "a_variable", - "bare_initializer", - "constant_initializer", - "duplicate", - "ones_initializer", - ) - six.assertCountEqual( - self, expected_checkpoint_names, named_variables.keys()) - - def testInitNotCalled(self): - - class NoInit(checkpointable.Checkpointable): - - def __init__(self): - pass - - with self.assertRaisesRegexp(RuntimeError, "__init__"): - NoInit().add_variable("var", shape=[]) - - -class CheckpointingTests(test.TestCase): - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testNamingWithOptimizer(self): - input_value = constant_op.constant([[3.]]) - network = MyNetwork() - # A nuisance Network using the same optimizer. Its slot variables should not - # go in the checkpoint, since it is never depended on. - other_network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root_checkpointable = Root(optimizer=optimizer, network=network) - if context.in_eager_mode(): - optimizer.minimize( - lambda: network(input_value), - global_step=root_checkpointable.global_step) - optimizer.minimize( - lambda: other_network(input_value), - global_step=root_checkpointable.global_step) - else: - train_op = optimizer.minimize( - network(input_value), global_step=root_checkpointable.global_step) - optimizer.minimize( - other_network(input_value), - global_step=root_checkpointable.global_step) - self.evaluate(variables.global_variables_initializer()) - self.evaluate(train_op) - named_variables, serialized_graph = checkpointable._serialize_object_graph( - root_checkpointable) - expected_checkpoint_names = ( - # Created in the root node, so no prefix. - "global_step", - # No name provided to track_checkpointable(), so the position is used - # instead (one-based). - "network/via_track_layer/kernel", - # track_checkpointable() with a name provided, so that's used - "network/_named_dense/kernel", - "network/_named_dense/bias", - # non-Layer dependency of the network - "network/_non_layer/a_variable", - # The optimizer creates two non-slot variables - "_optimizer/beta1_power", - "_optimizer/beta2_power", - # Slot variables - "network/via_track_layer/kernel/-OPTIMIZER_SLOT/_optimizer/m", - "network/via_track_layer/kernel/-OPTIMIZER_SLOT/_optimizer/v", - "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/m", - "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/v", - "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m", - "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/v", - ) - six.assertCountEqual(self, expected_checkpoint_names, - named_variables.keys()) - # Check that we've mapped to the right variable objects (not exhaustive) - self.assertEqual("global_step:0", named_variables["global_step"].name) - self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0", - named_variables["network/via_track_layer/kernel"].name) - self.assertEqual("my_network/checkpointable_dense_layer/kernel:0", - named_variables["network/_named_dense/kernel"].name) - self.assertEqual("beta1_power:0", - named_variables["_optimizer/beta1_power"].name) - self.assertEqual("beta2_power:0", - named_variables["_optimizer/beta2_power"].name) - # Spot check the generated protocol buffers. - self.assertEqual("_optimizer", - serialized_graph.nodes[0].children[0].local_name) - optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ - 0].node_id] - self.assertEqual("beta1_power", optimizer_node.variables[0].local_name) - self.assertEqual("beta1_power", optimizer_node.variables[0].full_name) - # Variable ordering is arbitrary but deterministic (alphabetized) - self.assertEqual( - "bias", optimizer_node.slot_variables[0].original_variable_local_name) - original_variable_owner = serialized_graph.nodes[ - optimizer_node.slot_variables[0].original_variable_node_id] - self.assertEqual("network/_named_dense/bias", - original_variable_owner.variables[0].checkpoint_key) - self.assertEqual("bias", original_variable_owner.variables[0].local_name) - self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) - self.assertEqual("network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m", - optimizer_node.slot_variables[0].checkpoint_key) - # We strip off the :0 suffix, as variable.name-based saving does. - self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam", - optimizer_node.slot_variables[0].full_name) - self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam:0", - optimizer.get_slot( - var=named_variables["network/_named_dense/bias"], - name="m").name) - - @test_util.run_in_graph_and_eager_modes() - def testSaveRestore(self): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root_checkpointable = Root(optimizer=optimizer, network=network) - input_value = constant_op.constant([[3.]]) - if context.in_eager_mode(): - optimizer.minimize( - lambda: network(input_value), - global_step=root_checkpointable.global_step) - else: - train_op = optimizer.minimize( - network(input_value), global_step=root_checkpointable.global_step) - self.evaluate(variables.global_variables_initializer()) - self.evaluate(train_op) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.])) - m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m") - self.evaluate(state_ops.assign(m_bias_slot, [1.5])) - serialized_graph, save_path = checkpointable.save( - file_prefix=prefix, - root_checkpointable=root_checkpointable, - global_step=root_checkpointable.global_step) - self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.])) - self.evaluate(state_ops.assign(root_checkpointable.global_step, 3)) - optimizer_variables = self.evaluate(optimizer.variables()) - self.evaluate(state_ops.assign(m_bias_slot, [-2.])) - # Immediate restoration - checkpointable.restore( - save_path=save_path, - root_checkpointable=root_checkpointable, - object_graph_proto=serialized_graph) - self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1])) - self.assertAllEqual(1, self.evaluate(root_checkpointable.global_step)) - self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) - with ops.Graph().as_default(): - on_create_network = MyNetwork() - on_create_optimizer = CheckpointableAdam(0.001) - on_create_root = Root( - optimizer=on_create_optimizer, network=on_create_network) - with self.test_session(graph=ops.get_default_graph()): - # Deferred restoration - checkpointable.restore( - save_path=save_path, - root_checkpointable=on_create_root, - object_graph_proto=serialized_graph) - on_create_network(constant_op.constant([[3.]])) # create variables - self.assertAllEqual(1, self.evaluate(on_create_root.global_step)) - self.assertAllEqual([42.], - self.evaluate( - on_create_network._named_dense.variables[1])) - on_create_m_bias_slot = on_create_optimizer.get_slot( - on_create_network._named_dense.variables[1], "m") - # Optimizer slot variables are created when the original variable is - # restored. - self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) - # beta1_power and beta2_power haven't been created yet, but everything - # else matches. - self.assertAllEqual(optimizer_variables[2:], - self.evaluate(on_create_optimizer.variables())) - on_create_optimizer._create_slots( - [resource_variable_ops.ResourceVariable([1.])]) - beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() - self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) - self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power)) - - def testDeferredRestorationUsageEager(self): - """An idiomatic eager execution example.""" - num_training_steps = 10 - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - latest_object_graph = None # Will be saved with the checkpoint eventually. - for training_continuation in range(3): - with ops.Graph().as_default(): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root = Root(optimizer=optimizer, network=network) - checkpointable.restore( - save_path=core_saver.latest_checkpoint(checkpoint_directory), - root_checkpointable=root, - object_graph_proto=latest_object_graph) - for _ in range(num_training_steps): - # TODO(allenl): Use a Dataset and serialize/checkpoint it. - input_value = constant_op.constant([[3.]]) - optimizer.minimize( - lambda: network(input_value), # pylint: disable=cell-var-from-loop - global_step=root.global_step) - latest_object_graph, _ = checkpointable.save( - file_prefix=checkpoint_prefix, - root_checkpointable=root) - self.assertEqual((training_continuation + 1) * num_training_steps, - root.global_step.numpy()) - - def testUsageGraph(self): - """Expected usage when graph building.""" - with context.graph_mode(): - num_training_steps = 10 - checkpoint_directory = self.get_temp_dir() - checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") - latest_object_graph = None - for training_continuation in range(3): - with ops.Graph().as_default(): - network = MyNetwork() - optimizer = CheckpointableAdam(0.001) - root = Root(optimizer=optimizer, network=network) - input_value = constant_op.constant([[3.]]) - train_op = optimizer.minimize( - network(input_value), - global_step=root.global_step) - init_op = variables.global_variables_initializer() - checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) - with self.test_session(graph=ops.get_default_graph()) as session: - if checkpoint_path is None: - self.assertEqual(0, training_continuation) - session.run(init_op) - # Another alternative would be to run initializers automatically - # if no checkpoint is being loaded. This would make deferred - # loading a bit more useful with graph execution. - else: - checkpointable.restore( - save_path=checkpoint_path, - root_checkpointable=root, - object_graph_proto=latest_object_graph, - session=session) - for _ in range(num_training_steps): - session.run(train_op) - latest_object_graph, _ = checkpointable.save( - file_prefix=checkpoint_prefix, - root_checkpointable=root, - session=session) - self.assertEqual((training_continuation + 1) * num_training_steps, - session.run(root.global_step)) - - def _get_checkpoint_name(self, name): - root = checkpointable.Checkpointable() - root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64) - named_variables, _ = checkpointable._serialize_object_graph(root) - checkpoint_name, = named_variables.keys() - with ops.name_scope("root/" + checkpoint_name): - pass # Make sure we can use this as an op name if we prefix it. - return checkpoint_name - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testVariableNameEscaping(self): - self.assertEqual(r"a_S__b_S__c", self._get_checkpoint_name(r"a/b/c")) - self.assertEqual(r"b", self._get_checkpoint_name(r"b")) - self.assertEqual(r"c_S__", self._get_checkpoint_name(r"c/")) - self.assertEqual(r"d_S___S_._", self._get_checkpoint_name(r"d/_S__")) - - @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) - def testNumberedPath(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() - root.track_checkpointable(leaf, name="leaf") - leaf.add_variable(name="v", shape=[]) - named_variables, _ = checkpointable._serialize_object_graph(root) - variable_name, = named_variables.keys() - self.assertEqual(r"leaf/v", variable_name) - - @test_util.run_in_graph_and_eager_modes() - def testLocalNameValidation(self): - root = checkpointable.Checkpointable() - leaf = checkpointable.Checkpointable() - with self.assertRaisesRegexp(ValueError, "invalid name"): - # Leading dashes are reserved, which avoids conflicts with un-named edges - # in paths and the optimizer slots identifier. - root.track_checkpointable(leaf, name="-unnamed-12") - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/eager/python/checkpointable_utils.py b/tensorflow/contrib/eager/python/checkpointable_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e57093bdbc34660c5a6d61fb5af46bcbbbb5f524 --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable_utils.py @@ -0,0 +1,772 @@ +"""Utilities for working with Checkpointable objects.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import collections +import weakref + +from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 +from tensorflow.python import pywrap_tensorflow +from tensorflow.python.client import session as session_lib +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import checkpointable as core_checkpointable +from tensorflow.python.training import checkpointable_utils as core_checkpointable_utils +from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.util import deprecation + + +_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names. + +# Keyword for identifying that the next bit of a checkpoint variable name is a +# slot name. Checkpoint names for slot variables look like: +# +# /<_OPTIMIZER_SLOTS_NAME>// +# +# Where is a full path from the checkpoint root to the +# variable being slotted for. +_OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT" +# Keyword for separating the path to an object from the name of an +# attribute in checkpoint names. Used like: +# /<_OBJECT_ATTRIBUTES_NAME>/ +_OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES" +# Key where the object graph proto is saved in a TensorBundle +_OBJECT_GRAPH_PROTO_KEY = "_CHECKPOINTABLE_OBJECT_GRAPH" + + +# TODO(allenl): If this ends up in a public API, consider adding LINT.IfChange +# or consolidating the implementation with get_variable. +def _default_getter(name, shape, dtype, initializer=None, + partition_info=None, **kwargs): + """A pared-down version of get_variable which does not reuse variables.""" + dtype = dtypes.as_dtype(dtype) + shape_object = tensor_shape.as_shape(shape) + with ops.init_scope(): + if initializer is None: + initializer, initializing_from_value = ( + variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access + name=name, shape=shape_object, dtype=dtype)) + else: + initializing_from_value = not callable(initializer) + # Same logic as get_variable + variable_dtype = dtype.base_dtype + if initializing_from_value: + if shape is not None: + raise ValueError("If initializer is a constant, do not specify shape.") + initial_value = initializer + else: + # Instantiate initializer if provided initializer is a type object. + if isinstance(initializer, type(init_ops.Initializer)): + initializer = initializer(dtype=dtype) + def initial_value(): + return initializer( + shape_object.as_list(), dtype=dtype, partition_info=partition_info) + return resource_variable_ops.ResourceVariable( + initial_value=initial_value, + name=name, + dtype=variable_dtype, + **kwargs + ) + + +def add_variable(checkpointable, name, shape=None, dtype=dtypes.float32, + initializer=None): + """Add a variable to a Checkpointable with no scope influence.""" + return checkpointable._add_variable_with_custom_getter( # pylint: disable=protected-access + name=name, shape=shape, dtype=dtype, + initializer=initializer, getter=_default_getter) + + +def _breadth_first_checkpointable_traversal(root_checkpointable): + """Find shortest paths to all variables owned by dependencies of root.""" + bfs_sorted = [] + to_visit = collections.deque([root_checkpointable]) + path_to_root = {root_checkpointable: ()} + while to_visit: + current_checkpointable = to_visit.popleft() + current_checkpointable._maybe_initialize_checkpointable() # pylint: disable=protected-access + bfs_sorted.append(current_checkpointable) + for child_checkpointable in ( + current_checkpointable._checkpoint_dependencies): # pylint: disable=protected-access + if child_checkpointable.ref not in path_to_root: + path_to_root[child_checkpointable.ref] = ( + path_to_root[current_checkpointable] + (child_checkpointable,)) + to_visit.append(child_checkpointable.ref) + return bfs_sorted, path_to_root + + +def _escape_local_name(name): + # We need to support slashes in local names for compatibility, since this + # naming scheme is being patched in to things like Layer.add_variable where + # slashes were previously accepted. We also want to use slashes to indicate + # edges traversed to reach the variable, so we escape forward slashes in + # names. + return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR) + .replace(r"/", _ESCAPE_CHAR + "S")) + + +def _object_prefix_from_path(path_to_root): + return "/".join( + (_escape_local_name(checkpointable.name) + for checkpointable in path_to_root)) + + +def _slot_variable_naming_for_optimizer(optimizer_path): + """Make a function for naming slot variables in an optimizer.""" + # Name slot variables: + # + # /<_OPTIMIZER_SLOTS_NAME>// + # + # where is exactly the checkpoint name used for the original + # variable, including the path from the checkpoint root and the local name in + # the object which owns it. Note that we only save slot variables if the + # variable it's slotting for is also being saved. + + optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path) + + def _name_slot_variable(variable_path, slot_name): + """With an optimizer specified, name a slot variable.""" + return (variable_path + + optimizer_identifier + + _escape_local_name(slot_name)) + + return _name_slot_variable + + +def _serialize_slot_variables(checkpointable_objects, node_ids, object_names): + """Gather and name slot variables.""" + non_slot_objects = list(checkpointable_objects) + slot_variables = {} + for checkpointable in non_slot_objects: + if isinstance(checkpointable, optimizer_lib.Optimizer): + naming_scheme = _slot_variable_naming_for_optimizer( + optimizer_path=object_names[checkpointable]) + slot_names = checkpointable.get_slot_names() + for slot_name in slot_names: + for original_variable_node_id, original_variable in enumerate( + non_slot_objects): + try: + slot_variable = checkpointable.get_slot( + original_variable, slot_name) + except AttributeError: + slot_variable = None + if slot_variable is None: + continue + slot_variable._maybe_initialize_checkpointable() # pylint: disable=protected-access + if slot_variable._checkpoint_dependencies: # pylint: disable=protected-access + # TODO(allenl): Gather dependencies of slot variables. + raise NotImplementedError( + "Currently only variables with no dependencies can be saved as " + "slot variables. File a feature request if this limitation " + "bothers you.") + if slot_variable in node_ids: + raise NotImplementedError( + "A slot variable was re-used as a dependency of a " + "Checkpointable object. This is not currently allowed. File a " + "feature request if this limitation bothers you.") + checkpoint_name = naming_scheme( + variable_path=object_names[original_variable], + slot_name=slot_name) + object_names[slot_variable] = checkpoint_name + slot_variable_node_id = len(checkpointable_objects) + node_ids[slot_variable] = slot_variable_node_id + checkpointable_objects.append(slot_variable) + slot_variable_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph + .Object.SlotVariableReference( + slot_name=slot_name, + original_variable_node_id=original_variable_node_id, + slot_variable_node_id=slot_variable_node_id)) + slot_variables.setdefault(checkpointable, []).append( + slot_variable_proto) + return slot_variables + + +def _serialize_checkpointables( + checkpointable_objects, node_ids, object_names, slot_variables): + """Name non-slot `Checkpointable`s and add them to `object_graph_proto`.""" + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + named_saveables = {} + + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + assert node_ids[checkpointable] == checkpoint_id + object_proto = object_graph_proto.nodes.add() + object_proto.slot_variables.extend(slot_variables.get(checkpointable, ())) + object_name = object_names[checkpointable] + for name, saveable in ( + checkpointable._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access + attribute = object_proto.attributes.add() + attribute.name = name + attribute.checkpoint_key = "%s/%s/%s" % ( + object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name)) + # Figure out the name-based Saver's name for this variable. + saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [saveable], convert_variable_to_tensor=False) + attribute.full_name, = saver_dict.keys() + named_saveables[attribute.checkpoint_key] = saveable + + for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access + child_proto = object_proto.children.add() + child_proto.node_id = node_ids[child.ref] + child_proto.local_name = child.name + + return named_saveables, object_graph_proto + + +def _serialize_object_graph(root_checkpointable): + """Determine checkpoint keys for variables and build a serialized graph. + + Non-slot variables are keyed based on a shortest path from the root saveable + to the object which owns the variable (i.e. the one which called + `Checkpointable._add_variable` to create it). + + Slot variables are keyed based on a shortest path to the variable being + slotted for, a shortest path to their optimizer, and the slot name. + + Args: + root_checkpointable: A `Checkpointable` object whose variables (including + the variables of dependencies, recursively) should be saved. + + Returns: + A tuple of (named_variables, object_graph_proto): + named_variables: A dictionary mapping names to variable objects. + object_graph_proto: A CheckpointableObjectGraph protocol buffer containing + the serialized object graph and variable references. + + Raises: + ValueError: If there are invalid characters in an optimizer's slot names. + """ + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(root_checkpointable)) + object_names = { + obj: _object_prefix_from_path(path) + for obj, path in path_to_root.items()} + node_ids = {node: node_id for node_id, node + in enumerate(checkpointable_objects)} + slot_variables = _serialize_slot_variables( + checkpointable_objects=checkpointable_objects, + node_ids=node_ids, + object_names=object_names) + return _serialize_checkpointables( + checkpointable_objects=checkpointable_objects, + node_ids=node_ids, + object_names=object_names, + slot_variables=slot_variables) + + +def gather_initializers(root_checkpointable): + """Traverse the object graph and find initialization ops. + + Looks for `Checkpointable` objects which are dependencies of + `root_checkpointable` and which have an `initializer` property. Includes + initializers for slot variables only if the variable they are slotting for and + the optimizer are dependencies of `root_checkpointable` (i.e. if they would be + saved with a checkpoint). + + Args: + root_checkpointable: A `Checkpointable` object to gather initializers for. + Returns: + A list of initialization ops. + """ + # TODO(allenl): Extract out gathering logic so the naming logic doesn't have + # to run. + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(root_checkpointable)) + object_names = { + obj: _object_prefix_from_path(path) + for obj, path in path_to_root.items()} + node_ids = {node: node_id for node_id, node + in enumerate(checkpointable_objects)} + _serialize_slot_variables( + checkpointable_objects=checkpointable_objects, + node_ids=node_ids, + object_names=object_names) + return [c.initializer for c in checkpointable_objects + if hasattr(c, "initializer") and c.initializer is not None] + + +class _NoRestoreSaveable(saver_lib.BaseSaverBuilder.SaveableObject): + + def __init__(self, tensor, name): + spec = saver_lib.BaseSaverBuilder.SaveSpec(tensor, "", name) + super(_NoRestoreSaveable, self).__init__(tensor, [spec], name) + + def restore(self, restored_tensors, restored_shapes): + return control_flow_ops.no_op() + + +class _LoadStatus(object): + """Abstract base for load status callbacks.""" + + @abc.abstractmethod + def assert_consumed(self): + """Raises an exception unless a non-trivial restoration has completed.""" + pass + + @abc.abstractmethod + def run_restore_ops(self, session=None): + """Runs restore ops from the checkpoint. Requires a valid checkpoint.""" + pass + + @abc.abstractmethod + def initialize_or_restore(self, session=None): + """Runs restore ops from the checkpoint, or initializes variables.""" + pass + + +class CheckpointLoadStatus(_LoadStatus): + """Checks the status of checkpoint loading and manages restore ops. + + Returned from `Saver.restore`. Since `restore` may defer the loading of values + in the checkpoint which don't yet have corresponding Python objects, + `CheckpointLoadStatus` provides a callback to verify that checkpoint loading + is complete (`assert_consumed`). + + When graph building, `restore` does not run restore ops itself since their + creation may be deferred. The `run_restore_ops` method must be called once all + Python objects with values to restore have been created and added to the + dependency graph (this does not necessarily have to be the whole checkpoint; + calling `run_restore_ops` while `assert_consumed` fails is supported and will + partially restore the checkpoint). + + See `Saver.restore` for usage examples. + """ + + def __init__(self, checkpoint, feed_dict): + self._checkpoint = checkpoint + self._feed_dict = feed_dict + + def assert_consumed(self): + """Asserts that all objects in the checkpoint have been created/matched. + + Returns: + `self` for chaining. + Raises: + AssertionError: If there are any Python objects in the dependency graph + which have not been restored from this checkpoint or a later `restore`, + or if there are any checkpointed values which have not been matched to + Python objects. + """ + for node_id, node in enumerate(self._checkpoint.object_graph_proto.nodes): + checkpointable = self._checkpoint.object_by_proto_id.get(node_id, None) + if checkpointable is None: + raise AssertionError("Unresolved object in checkpoint: %s" % (node,)) + if checkpointable._update_uid < self._checkpoint.restore_uid: # pylint: disable=protected-access + raise AssertionError( + "Object not assigned a value from checkpoint: %s" % (node,)) + if self._checkpoint.slot_restorations: + # Sanity check; this collection should be clear if everything has been + # restored. + raise AssertionError("Unresolved slot restorations: %s" % ( + self._checkpoint.slot_restorations,)) + if self._checkpoint.unused_attributes: + raise AssertionError( + ("Unused attributes in these objects (the attributes exist in the " + "checkpoint but not in the objects): %s") % ( + self._checkpoint.unused_attributes.items(),)) + return self + + def run_restore_ops(self, session=None): + """Run operations to restore objects in the dependency graph.""" + if context.in_eager_mode(): + return # Run eagerly + if session is None: + session = ops.get_default_session() + session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict) + + def initialize_or_restore(self, session=None): + """Alias for `run_restore_ops`. + + This method has a sibling in `InitializationOnlyStatus` which instead + initializes variables. That type is returned if no checkpoint is specified + in `Saver.restore`. + + Args: + session: The session to run restore ops in. If `None`, uses the default + session. + """ + self.run_restore_ops(session=session) + + +class InitializationOnlyStatus(_LoadStatus): + """Returned from `Saver.restore` when no checkpoint has been specified. + + Objects of this type have the same `assert_consumed` method as + `CheckpointLoadStatus`, but it always fails. However, + `initialize_or_restore` works on objects of both types, and will + initialize variables in `InitializationOnlyStatus` objects or restore them + otherwise. + """ + + def __init__(self, root_checkpointable): + self._root_checkpointable = root_checkpointable + + def assert_consumed(self): + """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" + raise AssertionError( + "No checkpoint specified (save_path=None); nothing is being restored.") + + def run_restore_ops(self, session=None): + """For consistency with `CheckpointLoadStatus`. + + Use `initialize_or_restore` for initializing if no checkpoint was passed + to `Saver.restore` and restoring otherwise. + + Args: + session: Not used. + """ + raise AssertionError( + "No checkpoint specified, so no restore ops are available " + "(save_path=None to Saver.restore).") + + def initialize_or_restore(self, session=None): + """Runs initialization ops for variables. + + Only objects which would be saved by `Saver.save` will be initialized. See + `gather_initializers` for details. + + This method does nothing when executing eagerly (initializers get run + eagerly). + + Args: + session: The session to run initialization ops in. If `None`, uses the + default session. + """ + if context.in_eager_mode(): + return # run eagerly + if session is None: + session = ops.get_default_session() + session.run(gather_initializers(self._root_checkpointable)) + + +_DEPRECATED_RESTORE_INSTRUCTIONS = ( + "Restoring a name-based tf.train.Saver checkpoint using the object-based " + "restore API. This mode uses global names to match variables, and so is " + "somewhat fragile. It also adds new restore ops to the graph each time it " + "is called. Prefer re-encoding training checkpoints in the object-based " + "format: run save() on the object-based saver (the same one this message " + "is coming from) and use that checkpoint in the future.") + + +class NameBasedSaverStatus(_LoadStatus): + """Status for loading a name-based training checkpoint.""" + + def __init__(self, object_saver, save_path): + self._object_saver = object_saver + self._save_path = save_path + + def assert_consumed(self): + """Assertion for consistency with `CheckpointLoadStatus`. Always fails.""" + raise AssertionError( + "Restoring a name-based checkpoint. No load status is available.") + + @deprecation.deprecated( + date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS) + def run_restore_ops(self, session=None): + """Load the name-based training checkpoint using a new `tf.train.Saver`.""" + if session is None and context.in_graph_mode(): + session = ops.get_default_session() + saver_lib.Saver(self._object_saver._global_variable_names()).restore( # pylint: disable=protected-access + sess=session, save_path=self._save_path) + + def initialize_or_restore(self, session=None): + """Alias for `run_restore_ops`.""" + self.run_restore_ops(session=session) + + +class _SessionWithFeedDictAdditions(session_lib.SessionInterface): + """Pretends to be a session, inserts extra feeds on run().""" + + def __init__(self, session, feed_additions): + self._wrapped_session = session + self._feed_additions = feed_additions + + def run(self, fetches, feed_dict=None, **kwargs): + if feed_dict is None: + feed_dict = {} + else: + feed_dict = feed_dict.copy() + feed_dict.update(self._feed_additions) + return self._wrapped_session.run( + fetches=fetches, feed_dict=feed_dict, **kwargs) + + +class Saver(object): + """Saves and restores a `Checkpointable` object and its dependencies. + + See `Checkpointable` for details of dependency management. `Saver` wraps + `tf.train.Saver` for saving, including extra information about the graph of + dependencies between Python objects. When restoring, it uses this information + about the save-time dependency graph to more robustly match objects with their + checkpointed values. When executing eagerly, it supports restoring variables + on object creation (see `Saver.restore`). + + Values in a checkpoint are mapped to `Checkpointable` Python objects + (`Variable`s, `Optimizer`s, `Layer`s) based on the names provided when the + checkpoint was written. To avoid breaking existing checkpoints when modifying + a class, dependency names (the names of attributes to which `Checkpointable` + objects are assigned) may not change. These names are local to objects, in + contrast to the `Variable.name`-based save/restore from `tf.train.Saver`, and + so allow additional program transformations. + """ + + def __init__(self, root_checkpointable): + """Configure saving. + + Args: + root_checkpointable: The root of the object graph to save/restore. This + object and all of its dependencies are saved in the checkpoint. When + restoring, objects are matched and restored starting from this root. + """ + # Allow passing in a weak reference to avoid reference cycles when + # `Checkpointable` objects save themselves. + self._root_checkpointable_ref = root_checkpointable + if context.in_graph_mode(): + self._file_prefix_placeholder = constant_op.constant("model") + else: + self._file_prefix_placeholder = None + + # Op caching for save + self._object_graph_feed_tensor = None + self._last_save_object_graph = None + self._last_save_saver = None + + # Op caching for restore + self._object_graph_restore_tensor = None + self._last_restore_object_graph = None + self._last_restore_checkpoint = None + + @property + def _root_checkpointable(self): + if isinstance(self._root_checkpointable_ref, weakref.ref): + derefed = self._root_checkpointable_ref() + assert derefed is not None + return derefed + else: + return self._root_checkpointable_ref + + def save(self, file_prefix, checkpoint_number=None, session=None): + """Save a training checkpoint. + + The saved checkpoint includes variables created by this object and any + Checkpointable objects it depends on at the time `Saver.save()` is called. + + Args: + file_prefix: A prefix to use for the checkpoint filenames + (/path/to/directory/and_a_prefix). Names are generated based on this + prefix and `checkpoint_number`, if provided. + checkpoint_number: An integer variable or Tensor, used to number + checkpoints. Typically this value is saved along with other variables in + training checkpoints, which will happen automatically if it was created + by `root_checkpointable` or one of its dependencies (via + `Checkpointable._add_variable`). + session: The session to evaluate variables in. Ignored when executing + eagerly. If not provided when graph building, the default session is + used. + + Returns: + The full path to the checkpoint. + """ + named_variables, graph_proto = _serialize_object_graph( + self._root_checkpointable) + in_graph_mode = context.in_graph_mode() + if in_graph_mode: + if session is None: + session = ops.get_default_session() + if self._object_graph_feed_tensor is None: + self._object_graph_feed_tensor = constant_op.constant( + "", dtype=dtypes.string) + object_graph_tensor = self._object_graph_feed_tensor + feed_additions = {object_graph_tensor: graph_proto.SerializeToString()} + else: + session = None + object_graph_tensor = constant_op.constant( + graph_proto.SerializeToString(), dtype=dtypes.string) + feed_additions = None + assert _OBJECT_GRAPH_PROTO_KEY not in named_variables + named_variables[_OBJECT_GRAPH_PROTO_KEY] = _NoRestoreSaveable( + tensor=object_graph_tensor, + name=_OBJECT_GRAPH_PROTO_KEY) + if not in_graph_mode or self._last_save_object_graph != graph_proto: + if self._last_save_object_graph is not None and in_graph_mode: + raise NotImplementedError( + "Using a single Saver to save a mutated object graph is not " + "currently supported when graph building. Use a different Saver " + "when the object graph changes (save ops will be duplicated), or " + "file a feature request if this limitation bothers you.") + saver = saver_lib.Saver(var_list=named_variables) + if in_graph_mode: + self._last_save_saver = saver + self._last_save_object_graph = graph_proto + else: + saver = self._last_save_saver + save_path = saver.save( + sess=_SessionWithFeedDictAdditions( + session=session, feed_additions=feed_additions), + save_path=file_prefix, + write_meta_graph=False, + global_step=checkpoint_number) + return save_path + + def _global_variable_names(self): + """Generate a `tf.train.Saver`-style `var_list` using `variable.name`s.""" + named_saveables, graph_proto = _serialize_object_graph( + self._root_checkpointable) + saver_names = {} + for object_proto in graph_proto.nodes: + for attribute_proto in object_proto.attributes: + saver_names[attribute_proto.full_name] = named_saveables[ + attribute_proto.checkpoint_key] + return saver_names + + def restore(self, save_path, session=None): + """Restore a training checkpoint. + + Restores `root_checkpointable` and any objects that it tracks + (transitive). Either assigns values immediately if variables to restore have + been created already, or defers restoration until the variables are + created. Dependencies added to the `root_checkpointable` passed to the + constructor after this call will be matched if they have a corresponding + object in the checkpoint. + + When building a graph, restorations are added to the graph but not run. A + session is required to retrieve checkpoint metadata. + + To disallow deferred loading, assert immediately that all checkpointed + variables have been matched to variable objects: + + ```python + saver = Saver(root) + saver.restore(path).assert_consumed() + ``` + + An exception will be raised unless every object was matched and its + variables already exist. + + When graph building, `assert_consumed()` indicates that all of the restore + ops which will be created for this checkpoint have been created. They can be + run via the `run_restore_ops()` function of the status object: + + ```python + saver.restore(path).assert_consumed().run_restore_ops() + ``` + + If the checkpoint has not been consumed completely, then the list of restore + ops will grow as more objects are added to the dependency graph. + + Name-based `tf.train.Saver` checkpoints can be loaded using this + method. There is no deferred loading, and names are used to match + variables. No restore ops are created/run until `run_restore_ops()` or + `initialize_or_restore()` are called on the returned status object, even + when executing eagerly. Re-encode name-based checkpoints using this + object-based `Saver.save` as soon as possible. + + Args: + save_path: The path to the checkpoint, as returned by `save` or + `tf.train.latest_checkpoint`. If None (as when there is no latest + checkpoint for `tf.train.latest_checkpoint` to return), returns an + object which may run initializers for objects in the dependency + graph. If the checkpoint was written by the name-based `tf.train.Saver`, + names are used to match variables. + session: The session to retrieve metadata with. Ignored when executing + eagerly. If not provided when graph building, the default session is + used. + + Returns: + A load status object, which can be used to make assertions about the + status of checkpoint restoration and run initialization/restore ops + (of type `CheckpointLoadStatus`, or `InitializationOnlyStatus` if + `save_path` is `None`). + + If `save_path` points to a name-based checkpoint, a `NameBasedSaverStatus` + object is returned which runs restore ops from a name-based saver. + """ + if save_path is None: + return InitializationOnlyStatus(self._root_checkpointable) + in_graph_mode = context.in_graph_mode() + if in_graph_mode: + if session is None: + session = ops.get_default_session() + file_prefix_tensor = self._file_prefix_placeholder + file_prefix_feed_dict = {self._file_prefix_placeholder: save_path} + else: + session = None + file_prefix_tensor = constant_op.constant(save_path) + file_prefix_feed_dict = None + try: + if not in_graph_mode or self._object_graph_restore_tensor is None: + object_graph_string, = io_ops.restore_v2( + prefix=file_prefix_tensor, + tensor_names=[_OBJECT_GRAPH_PROTO_KEY], + shape_and_slices=[""], + dtypes=[dtypes.string], + name="object_graph_proto_read") + if in_graph_mode: + self._object_graph_restore_tensor = object_graph_string + if in_graph_mode: + object_graph_string = session.run( + self._object_graph_restore_tensor, + feed_dict=file_prefix_feed_dict) + else: + object_graph_string = object_graph_string.numpy() + except errors_impl.NotFoundError: + # The object graph proto does not exist in this checkpoint. Try again with + # name-based saving. + return NameBasedSaverStatus(self, save_path) + + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + object_graph_proto.ParseFromString(object_graph_string) + if in_graph_mode and object_graph_proto == self._last_restore_object_graph: + checkpoint = self._last_restore_checkpoint + else: + if in_graph_mode: + dtype_map = None + else: + reader = pywrap_tensorflow.NewCheckpointReader(save_path) + dtype_map = reader.get_variable_to_dtype_map() + checkpoint = core_checkpointable_utils._Checkpoint( # pylint: disable=protected-access + object_graph_proto=object_graph_proto, + save_path=file_prefix_tensor, + dtype_map=dtype_map) + if in_graph_mode: + if self._last_restore_object_graph is not None: + raise NotImplementedError( + "Using a single Saver to restore different object graphs is not " + "currently supported when graph building. Use a different Saver " + "for each object graph (restore ops will be duplicated), or " + "file a feature request if this limitation bothers you.") + self._last_restore_checkpoint = checkpoint + self._last_restore_object_graph = object_graph_proto + core_checkpointable._CheckpointPosition( # pylint: disable=protected-access + checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable) + load_status = CheckpointLoadStatus( + checkpoint, feed_dict=file_prefix_feed_dict) + return load_status diff --git a/tensorflow/contrib/eager/python/checkpointable_utils_test.py b/tensorflow/contrib/eager/python/checkpointable_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3d6a200276754d96b6c539cc98c397d09b999b9f --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable_utils_test.py @@ -0,0 +1,1013 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os +import weakref + +import six + +from tensorflow.contrib.eager.python import checkpointable_utils +from tensorflow.contrib.eager.python import network as network_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.layers import base +from tensorflow.python.layers import core +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import adam +from tensorflow.python.training import checkpointable +from tensorflow.python.training import saver as core_saver +from tensorflow.python.training import training_util + + +class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): + + def __init__(self, *args, **kwargs): + checkpointable.Checkpointable.__init__(self) + core.Dense.__init__(self, *args, **kwargs) + + def add_variable(self, name, shape, **kwargs): + # Calls both Checkpointable._add_variable and Layer.add_variable. Eventually + # Layer.add_variable should inherit from Checkpointable and simply call + # super and then do post-processing. + return checkpointable.Checkpointable._add_variable_with_custom_getter( + self, + name=name, + shape=shape, + getter=functools.partial(core.Dense.add_variable, self), + **kwargs) + + +# pylint: disable=not-callable +class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): + + def __setattr__(self, name, value): + if isinstance(value, base.Layer): + self.track_layer(value, name=name) + # Checkpointable is next in the method resolution order, so this will catch + # Checkpointable objects which aren't Layers. + super(CheckpointableNetwork, self).__setattr__(name, value) + + def track_layer(self, layer, name): + self._track_checkpointable(layer, name=name) + return super(CheckpointableNetwork, self).track_layer(layer) + + +class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): + + # NOTE: Copied from Optimizer with modifications to use add_variable + # for non-slot variables. These contortions are necessary to maintain + # checkpoint compatibility with variable.name based saving. + # TODO(allenl): Make this cleaner. + def _create_non_slot_variable(self, initial_value, name, colocate_with): + """Add an extra variable, not associated with a slot.""" + if context.in_graph_mode(): + graph = colocate_with.graph + else: + graph = None + + key = (name, graph) + v = self._non_slot_dict.get(key, None) + if v is None: + with ops.colocate_with(colocate_with): + def _variable_getter(name, shape, dtype, initializer): + del shape, dtype # not used, but there for compatibility + return variable_scope.variable( + name=name, initial_value=initializer, trainable=False) + + initial_value = ops.convert_to_tensor(initial_value) + v = self._add_variable_with_custom_getter( + name=name, + shape=initial_value.get_shape(), + initializer=initial_value, + getter=_variable_getter) + + self._non_slot_dict[key] = v + + return v + + +class NonLayerCheckpointable(checkpointable.Checkpointable): + + def __init__(self): + super(NonLayerCheckpointable, self).__init__() + self.a_variable = checkpointable_utils.add_variable( + self, name="a_variable", shape=[]) + + +class MyNetwork(CheckpointableNetwork): + """A concrete Network for testing.""" + + def __init__(self): + super(MyNetwork, self).__init__() + self._named_dense = CheckpointableDenseLayer(1, use_bias=True) + self._via_track_layer = self.track_layer( + CheckpointableDenseLayer(1, use_bias=False), name="via_track_layer") + # We can still track Checkpointables which aren't Layers. + self._non_layer = NonLayerCheckpointable() + + def call(self, values): + return self._via_track_layer(self._named_dense(values)) + + +class Checkpoint(checkpointable.Checkpointable): + """A utility class which groups `Checkpointable` objects.""" + + def __init__(self, **kwargs): + super(Checkpoint, self).__init__() + for k, v in sorted(kwargs.items(), key=lambda item: item[0]): + setattr(self, k, v) + self._save_counter = None # Created lazily for restore-on-create. + self._saver = checkpointable_utils.Saver(weakref.ref(self)) + + @property + def save_counter(self): + """An integer variable which starts at zero and is incremented on save. + + Used to number checkpoints. + + Returns: + The save counter variable. + """ + if self._save_counter is None: + # Initialized to 0 and incremented before saving. + self._save_counter = checkpointable_utils.add_variable( + self, name="save_counter", initializer=0, dtype=dtypes.int64) + return self._save_counter + + def save(self, file_prefix, session=None): + assign_op = self.save_counter.assign_add(1) + if context.in_graph_mode(): + if session is None: + session = ops.get_default_session() + session.run(assign_op) + return self._saver.save( + file_prefix=file_prefix, + checkpoint_number=self.save_counter, + session=session) + + def restore(self, save_path): + status = self._saver.restore(save_path=save_path) + # Create the save counter now so it gets initialized with other variables + # when graph building. Creating it earlier would lead to double + # initialization when executing eagerly. + self.save_counter # pylint: disable=pointless-statement + return status + + +class InterfaceTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testAddVariable(self): + obj = NonLayerCheckpointable() + with self.assertRaisesRegexp(ValueError, "do not specify shape"): + checkpointable_utils.add_variable( + obj, name="shape_specified_twice", shape=[], initializer=1) + constant_initializer = checkpointable_utils.add_variable( + obj, name="constant_initializer", initializer=1) + with variable_scope.variable_scope("some_variable_scope"): + ones_initializer = checkpointable_utils.add_variable( + obj, + name="ones_initializer", + shape=[2], + initializer=init_ops.ones_initializer(dtype=dtypes.float32)) + bare_initializer = checkpointable_utils.add_variable( + obj, + name="bare_initializer", + shape=[2, 2], + dtype=dtypes.float64, + initializer=init_ops.zeros_initializer) + + # Even in graph mode, there are no naming conflicts between objects, only + # naming conflicts within an object. + other_duplicate = resource_variable_ops.ResourceVariable( + name="duplicate", initial_value=1.) + duplicate = checkpointable_utils.add_variable( + obj, name="duplicate", shape=[]) + with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"): + checkpointable_utils.add_variable(obj, name="duplicate", shape=[]) + + self.evaluate(checkpointable_utils.gather_initializers(obj)) + self.assertEqual("constant_initializer:0", constant_initializer.name) + self.assertEqual(1, self.evaluate(constant_initializer)) + self.assertEqual("some_variable_scope/ones_initializer:0", + ones_initializer.name) + self.assertAllEqual([1, 1], self.evaluate(ones_initializer)) + self.assertAllEqual([[0., 0.], + [0., 0.]], self.evaluate(bare_initializer)) + self.assertEqual("a_variable:0", obj.a_variable.name) + self.assertEqual("duplicate:0", other_duplicate.name) + if context.in_graph_mode(): + # The .name attribute may be globally influenced, but the checkpoint name + # won't be (tested below). + self.assertEqual("duplicate_1:0", duplicate.name) + else: + # When executing eagerly, there's no uniquification of variable names. The + # checkpoint name will be the same. + self.assertEqual("duplicate:0", duplicate.name) + named_variables, _ = checkpointable_utils._serialize_object_graph(obj) + expected_checkpoint_names = ( + "a_variable/.ATTRIBUTES/VARIABLE_VALUE", + "bare_initializer/.ATTRIBUTES/VARIABLE_VALUE", + "constant_initializer/.ATTRIBUTES/VARIABLE_VALUE", + "duplicate/.ATTRIBUTES/VARIABLE_VALUE", + "ones_initializer/.ATTRIBUTES/VARIABLE_VALUE", + ) + six.assertCountEqual( + self, expected_checkpoint_names, named_variables.keys()) + + def testInitNotCalled(self): + + class NoInit(checkpointable.Checkpointable): + + def __init__(self): + pass + + # __init__ for Checkpointable will be called implicitly. + checkpointable_utils.add_variable(NoInit(), "var", shape=[]) + + def testShapeDtype(self): + root = checkpointable.Checkpointable() + v1 = checkpointable_utils.add_variable( + root, name="v1", initializer=3., dtype=dtypes.float64) + self.assertEqual(dtypes.float64, v1.dtype) + v2 = checkpointable_utils.add_variable( + root, + name="v2", + shape=[3], + initializer=init_ops.ones_initializer, + dtype=dtypes.float64) + self.assertEqual(dtypes.float64, v2.dtype) + self.assertAllEqual([1., 1., 1.], self.evaluate(v2)) + + +class CheckpointingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNamingWithOptimizer(self): + input_value = constant_op.constant([[3.]]) + network = MyNetwork() + # A nuisance Network using the same optimizer. Its slot variables should not + # go in the checkpoint, since it is never depended on. + other_network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + optimizer_step = training_util.get_or_create_global_step() + root_checkpointable = Checkpoint( + optimizer=optimizer, network=network, optimizer_step=optimizer_step) + if context.in_eager_mode(): + optimizer.minimize( + lambda: network(input_value), + global_step=optimizer_step) + optimizer.minimize( + lambda: other_network(input_value), + global_step=optimizer_step) + else: + train_op = optimizer.minimize( + network(input_value), global_step=optimizer_step) + optimizer.minimize( + other_network(input_value), + global_step=optimizer_step) + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) + self.evaluate(train_op) + named_variables, serialized_graph = ( + checkpointable_utils._serialize_object_graph(root_checkpointable)) + expected_checkpoint_names = ( + # Created in the root node, so no prefix. + "optimizer_step", + # No name provided to track_checkpointable(), so the position is used + # instead (one-based). + "network/via_track_layer/kernel", + # track_checkpointable() with a name provided, so that's used + "network/_named_dense/kernel", + "network/_named_dense/bias", + # non-Layer dependency of the network + "network/_non_layer/a_variable", + # The optimizer creates two non-slot variables + "optimizer/beta1_power", + "optimizer/beta2_power", + # Slot variables + "network/via_track_layer/kernel/.OPTIMIZER_SLOT/optimizer/m", + "network/via_track_layer/kernel/.OPTIMIZER_SLOT/optimizer/v", + "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m", + "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v", + "network/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m", + "network/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v", + ) + suffix = "/.ATTRIBUTES/VARIABLE_VALUE" + expected_checkpoint_names = [ + name + suffix for name in expected_checkpoint_names] + six.assertCountEqual(self, expected_checkpoint_names, + named_variables.keys()) + # Check that we've mapped to the right variable objects (not exhaustive) + self.assertEqual( + "global_step:0", + named_variables["optimizer_step" + suffix].name) + self.assertEqual( + "my_network/checkpointable_dense_layer_1/kernel:0", + named_variables["network/via_track_layer/kernel" + suffix].name) + self.assertEqual( + "my_network/checkpointable_dense_layer/kernel:0", + named_variables["network/_named_dense/kernel" + suffix].name) + self.assertEqual( + "beta1_power:0", + named_variables["optimizer/beta1_power" + suffix].name) + self.assertEqual( + "beta2_power:0", + named_variables["optimizer/beta2_power" + suffix].name) + # Spot check the generated protocol buffers. + self.assertEqual("optimizer", + serialized_graph.nodes[0].children[1].local_name) + optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ + 1].node_id] + self.assertEqual("beta1_power", + optimizer_node.children[0].local_name) + self.assertEqual("beta1_power", + serialized_graph.nodes[optimizer_node.children[0].node_id] + .attributes[0].full_name) + self.assertEqual( + "my_network/checkpointable_dense_layer/kernel", + serialized_graph.nodes[optimizer_node.slot_variables[0] + .original_variable_node_id] + .attributes[0].full_name) + # We strip off the :0 suffix, as variable.name-based saving does. + self.assertEqual( + "my_network/checkpointable_dense_layer/kernel/Adam", + serialized_graph.nodes[optimizer_node.slot_variables[0] + .slot_variable_node_id] + .attributes[0].full_name) + self.assertEqual( + "my_network/checkpointable_dense_layer/kernel/Adam:0", + optimizer.get_slot( + var=named_variables["network/_named_dense/kernel" + suffix], + name="m").name) + self.assertEqual( + "network/_named_dense/kernel" + suffix, + serialized_graph.nodes[ + optimizer_node.slot_variables[0] + .original_variable_node_id].attributes[0].checkpoint_key) + self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) + self.assertEqual( + "network/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix, + serialized_graph.nodes[ + optimizer_node.slot_variables[0] + .slot_variable_node_id].attributes[0].checkpoint_key) + + @test_util.run_in_graph_and_eager_modes() + def testSaveRestore(self): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root_checkpointable = Checkpoint(optimizer=optimizer, network=network) + input_value = constant_op.constant([[3.]]) + if context.in_eager_mode(): + optimizer.minimize( + lambda: network(input_value)) + else: + train_op = optimizer.minimize(network(input_value)) + # TODO(allenl): Make initialization more pleasant when graph building. + root_checkpointable.save_counter # pylint: disable=pointless-statement + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) + self.evaluate(train_op) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.])) + m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m") + self.evaluate(state_ops.assign(m_bias_slot, [1.5])) + save_path = root_checkpointable.save(file_prefix=prefix) + self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.])) + self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3)) + optimizer_variables = self.evaluate(optimizer.variables()) + self.evaluate(state_ops.assign(m_bias_slot, [-2.])) + # Immediate restoration + status = root_checkpointable.restore(save_path=save_path).assert_consumed() + status.run_restore_ops() + self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1])) + self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter)) + self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) + if context.in_graph_mode(): + return # Restore-on-create is only supported when executing eagerly + on_create_network = MyNetwork() + on_create_optimizer = CheckpointableAdam(0.001) + on_create_root = Checkpoint( + optimizer=on_create_optimizer, network=on_create_network) + # Deferred restoration + status = on_create_root.restore(save_path=save_path) + on_create_network(constant_op.constant([[3.]])) # create variables + self.assertAllEqual(1, self.evaluate(on_create_root.save_counter)) + self.assertAllEqual([42.], + self.evaluate( + on_create_network._named_dense.variables[1])) + on_create_m_bias_slot = on_create_optimizer.get_slot( + on_create_network._named_dense.variables[1], "m") + # Optimizer slot variables are created when the original variable is + # restored. + self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) + self.assertAllEqual(optimizer_variables[2:], + self.evaluate(on_create_optimizer.variables())) + on_create_optimizer._create_slots( + [resource_variable_ops.ResourceVariable([1.])]) + status.assert_consumed() + beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() + self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) + self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power)) + + # TODO(allenl): Debug garbage created by this test in python3. + def testDeferredRestorationUsageEager(self): + """An idiomatic eager execution example.""" + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + for training_continuation in range(3): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root = Checkpoint( + optimizer=optimizer, network=network, + optimizer_step=training_util.get_or_create_global_step()) + root.restore(core_saver.latest_checkpoint(checkpoint_directory)) + for _ in range(num_training_steps): + # TODO(allenl): Use a Dataset and serialize/checkpoint it. + input_value = constant_op.constant([[3.]]) + optimizer.minimize( + lambda: network(input_value), # pylint: disable=cell-var-from-loop + global_step=root.optimizer_step) + root.save(file_prefix=checkpoint_prefix) + self.assertEqual((training_continuation + 1) * num_training_steps, + root.optimizer_step.numpy()) + + def testUsageGraph(self): + """Expected usage when graph building.""" + with context.graph_mode(): + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + for training_continuation in range(3): + with ops.Graph().as_default(): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root = Checkpoint( + optimizer=optimizer, network=network, + global_step=training_util.get_or_create_global_step()) + input_value = constant_op.constant([[3.]]) + train_op = optimizer.minimize( + network(input_value), + global_step=root.global_step) + checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + with self.test_session(graph=ops.get_default_graph()) as session: + status = root.restore(save_path=checkpoint_path) + status.initialize_or_restore(session=session) + if checkpoint_path is None: + self.assertEqual(0, training_continuation) + with self.assertRaises(AssertionError): + status.assert_consumed() + else: + status.assert_consumed() + for _ in range(num_training_steps): + session.run(train_op) + root.save(file_prefix=checkpoint_prefix, session=session) + self.assertEqual((training_continuation + 1) * num_training_steps, + session.run(root.global_step)) + self.assertEqual(training_continuation + 1, + session.run(root.save_counter)) + + @test_util.run_in_graph_and_eager_modes() + def testAgnosticUsage(self): + """Graph/eager agnostic usage.""" + # Does create garbage when executing eagerly due to ops.Graph() creation. + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + for training_continuation in range(3): + with ops.Graph().as_default(), self.test_session( + graph=ops.get_default_graph()): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root = Checkpoint( + optimizer=optimizer, network=network, + global_step=training_util.get_or_create_global_step()) + checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + status = root.restore(save_path=checkpoint_path) + input_value = constant_op.constant([[3.]]) + train_fn = functools.partial( + optimizer.minimize, + functools.partial(network, input_value), + global_step=root.global_step) + if context.in_graph_mode(): + train_fn = functools.partial(self.evaluate, train_fn()) + status.initialize_or_restore() + for _ in range(num_training_steps): + train_fn() + root.save(file_prefix=checkpoint_prefix) + self.assertEqual((training_continuation + 1) * num_training_steps, + self.evaluate(root.global_step)) + self.assertEqual(training_continuation + 1, + self.evaluate(root.save_counter)) + + def _get_checkpoint_name(self, name): + root = checkpointable.Checkpointable() + checkpointable_utils.add_variable( + root, name=name, shape=[1, 2], dtype=dtypes.float64) + named_variables, _ = checkpointable_utils._serialize_object_graph(root) + checkpoint_name, = named_variables.keys() + with ops.name_scope("root/" + checkpoint_name): + pass # Make sure we can use this as an op name if we prefix it. + return checkpoint_name + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableNameEscaping(self): + suffix = "/.ATTRIBUTES/VARIABLE_VALUE" + self.assertEqual(r"a.Sb.Sc" + suffix, self._get_checkpoint_name(r"a/b/c")) + self.assertEqual(r"b" + suffix, self._get_checkpoint_name(r"b")) + self.assertEqual(r"c.S" + suffix, self._get_checkpoint_name(r"c/")) + self.assertEqual(r"d.S..S" + suffix, self._get_checkpoint_name(r"d/.S")) + self.assertEqual(r"d.S..ATTRIBUTES.Sf" + suffix, + self._get_checkpoint_name(r"d/.ATTRIBUTES/f")) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNumberedPath(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + root.leaf = leaf + checkpointable_utils.add_variable(leaf, name="v", shape=[]) + named_variables, _ = checkpointable_utils._serialize_object_graph(root) + variable_name, = named_variables.keys() + self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", variable_name) + + @test_util.run_in_graph_and_eager_modes() + def testLocalNameValidation(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + # Dots are escaped, which avoids conflicts with reserved names. + root._track_checkpointable(leaf, name=".ATTRIBUTES") + checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[]) + named_variables, _ = checkpointable_utils._serialize_object_graph(root) + name, = named_variables.keys() + self.assertEqual(name, "..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE") + + @test_util.run_in_graph_and_eager_modes() + def testLateDependencyTracking(self): + + class Dependency(checkpointable.Checkpointable): + + def build(self): + self.var = checkpointable_utils.add_variable( + self, "var", initializer=0.) + + class LateDependencies(checkpointable.Checkpointable): + + def add_dep(self): + self.dep = Dependency() + self.dep.build() + + original = LateDependencies() + original.add_dep() + self.evaluate(state_ops.assign(original.dep.var, 123.)) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpointable_utils.Saver(original).save(checkpoint_prefix) + load_into = LateDependencies() + status = checkpointable_utils.Saver(load_into).restore(save_path) + with self.assertRaises(AssertionError): + status.assert_consumed() + load_into.add_dep() + status.assert_consumed() + status.run_restore_ops() + self.assertEqual(123., self.evaluate(load_into.dep.var)) + + @test_util.run_in_graph_and_eager_modes() + def testDepAfterVar(self): + + class Dependency(checkpointable.Checkpointable): + + def build(self): + self.var = checkpointable_utils.add_variable( + self, "var", initializer=0.) + + class DepAfterVar(checkpointable.Checkpointable): + + def add_dep(self): + dep = Dependency() + dep.build() + self.dep = dep + + dep_after_var = DepAfterVar() + dep_after_var.add_dep() + self.evaluate(state_ops.assign(dep_after_var.dep.var, -14.)) + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_path = checkpointable_utils.Saver(dep_after_var).save( + checkpoint_prefix) + + loaded_dep_after_var = DepAfterVar() + status = checkpointable_utils.Saver(loaded_dep_after_var).restore(save_path) + loaded_dep_after_var.add_dep() + status.assert_consumed() + status.run_restore_ops() + self.assertEqual(-14., self.evaluate(loaded_dep_after_var.dep.var)) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testDeferredSlotRestoration(self): + checkpoint_directory = self.get_temp_dir() + + root = checkpointable.Checkpointable() + root.var = checkpointable_utils.add_variable( + root, name="var", initializer=0.) + optimizer = CheckpointableAdam(0.1) + if context.in_graph_mode(): + train_op = optimizer.minimize(root.var) + # Note that `optimizer` has not been added as a dependency of + # `root`. Create a one-off grouping so that slot variables for `root.var` + # get initialized too. + self.evaluate(checkpointable_utils.gather_initializers( + Checkpoint(root=root, optimizer=optimizer))) + self.evaluate(train_op) + else: + optimizer.minimize(root.var.read_value) + self.evaluate(state_ops.assign(root.var, 12.)) + no_slots_path = checkpointable_utils.Saver(root).save( + os.path.join(checkpoint_directory, "no_slots")) + root.optimizer = optimizer + self.evaluate(state_ops.assign(root.var, 13.)) + self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var), + 14.)) + slots_path = checkpointable_utils.Saver(root).save( + os.path.join(checkpoint_directory, "with_slots")) + new_root = checkpointable.Checkpointable() + # Load the slot-containing checkpoint (deferred), then immediately overwrite + # the non-slot variable (also deferred). + slot_status = checkpointable_utils.Saver(new_root).restore(slots_path) + no_slot_status = checkpointable_utils.Saver(new_root).restore(no_slots_path) + with self.assertRaises(AssertionError): + no_slot_status.assert_consumed() + new_root.var = checkpointable_utils.add_variable( + new_root, name="var", shape=[]) + no_slot_status.assert_consumed() + no_slot_status.run_restore_ops() + self.assertEqual(12., self.evaluate(new_root.var)) + new_root.optimizer = CheckpointableAdam(0.1) + with self.assertRaisesRegexp(AssertionError, "beta1_power"): + slot_status.assert_consumed() + self.assertEqual(12., self.evaluate(new_root.var)) + if context.in_eager_mode(): + # Slot variables are only created with restoring initializers when + # executing eagerly. + self.assertEqual(14., self.evaluate( + new_root.optimizer.get_slot(name="m", var=new_root.var))) + else: + self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var), + None) + if context.in_graph_mode(): + train_op = new_root.optimizer.minimize(new_root.var) + # The slot variable now exists; restore() didn't create it, but we should + # now have a restore op for it. + slot_status.run_restore_ops() + self.assertEqual(14., self.evaluate( + new_root.optimizer.get_slot(name="m", var=new_root.var))) + self.evaluate(train_op) + else: + new_root.optimizer.minimize(new_root.var.read_value) + slot_status.assert_consumed() + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testOverlappingRestores(self): + checkpoint_directory = self.get_temp_dir() + save_root = checkpointable.Checkpointable() + save_root.dep = checkpointable.Checkpointable() + save_root.dep.var = checkpointable_utils.add_variable( + save_root.dep, name="var", initializer=0.) + self.evaluate(state_ops.assign(save_root.dep.var, 12.)) + saver = checkpointable_utils.Saver(save_root) + first_path = saver.save(os.path.join(checkpoint_directory, "first")) + self.evaluate(state_ops.assign(save_root.dep.var, 13.)) + second_path = saver.save(os.path.join(checkpoint_directory, "second")) + + first_root = checkpointable.Checkpointable() + second_root = checkpointable.Checkpointable() + first_status = checkpointable_utils.Saver(first_root).restore(first_path) + second_status = checkpointable_utils.Saver(second_root).restore(second_path) + load_dep = checkpointable.Checkpointable() + load_dep.var = checkpointable_utils.add_variable( + load_dep, name="var", shape=[]) + first_root.dep = load_dep + first_status.assert_consumed() + first_status.run_restore_ops() + self.assertEqual(12., self.evaluate(load_dep.var)) + second_root.dep = load_dep + second_status.assert_consumed() + second_status.run_restore_ops() + self.assertEqual(13., self.evaluate(load_dep.var)) + + # Try again with the order of the restore() reversed. The last restore + # determines the final value. + first_root = checkpointable.Checkpointable() + second_root = checkpointable.Checkpointable() + second_status = checkpointable_utils.Saver(second_root).restore(second_path) + first_status = checkpointable_utils.Saver(first_root).restore(first_path) + load_dep = checkpointable.Checkpointable() + load_dep.var = checkpointable_utils.add_variable( + load_dep, name="var", shape=[]) + first_root.dep = load_dep + first_status.assert_consumed() + first_status.run_restore_ops() + self.assertEqual(12., self.evaluate(load_dep.var)) + second_root.dep = load_dep + second_status.assert_consumed() + second_status.run_restore_ops() + self.assertEqual(12., self.evaluate(load_dep.var)) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testAmbiguousLoad(self): + # Not OK to split one checkpoint object into two + checkpoint_directory = self.get_temp_dir() + save_root = checkpointable.Checkpointable() + save_root.dep_one = checkpointable.Checkpointable() + save_root.dep_two = checkpointable.Checkpointable() + dep_three = checkpointable.Checkpointable() + save_root.dep_one.dep_three = dep_three + save_root.dep_two.dep_three = dep_three + checkpointable_utils.add_variable(dep_three, name="var", initializer=0.) + self.evaluate(checkpointable_utils.gather_initializers(save_root)) + save_path = checkpointable_utils.Saver(save_root).save( + os.path.join(checkpoint_directory, "ckpt")) + load_root = checkpointable.Checkpointable() + checkpointable_utils.Saver(load_root).restore(save_path) + load_root.dep_one = checkpointable.Checkpointable() + load_root.dep_two = checkpointable.Checkpointable() + load_root.dep_one.dep_three = checkpointable.Checkpointable() + with self.assertRaisesRegexp(AssertionError, + "resolved to different objects"): + load_root.dep_two.dep_three = checkpointable.Checkpointable() + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testObjectsCombined(self): + # Currently fine to load two checkpoint objects into one Python object + checkpoint_directory = self.get_temp_dir() + save_root = checkpointable.Checkpointable() + save_root.dep_one = checkpointable.Checkpointable() + save_root.dep_two = checkpointable.Checkpointable() + checkpointable_utils.add_variable( + save_root.dep_one, name="var1", initializer=32., dtype=dtypes.float64) + checkpointable_utils.add_variable( + save_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64) + self.evaluate(checkpointable_utils.gather_initializers(save_root)) + save_path = checkpointable_utils.Saver(save_root).save( + os.path.join(checkpoint_directory, "ckpt")) + load_root = checkpointable.Checkpointable() + load_root.dep_one = checkpointable.Checkpointable() + load_root.dep_two = load_root.dep_one + v1 = checkpointable_utils.add_variable( + load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64) + v2 = checkpointable_utils.add_variable( + load_root.dep_one, name="var2", shape=[], dtype=dtypes.float64) + status = checkpointable_utils.Saver(load_root).restore( + save_path).assert_consumed() + status.run_restore_ops() + self.assertEqual(32., self.evaluate(v1)) + self.assertEqual(64., self.evaluate(v2)) + + @test_util.run_in_graph_and_eager_modes() + def testDependencyLoop(self): + # Note: this test creates garbage during eager execution because it + # purposefully creates a reference cycle. + first = checkpointable.Checkpointable() + second = checkpointable.Checkpointable() + first.second = second + second.first = first + first.v = checkpointable_utils.add_variable( + first, "v1", initializer=[3., 1., 4.]) + second.v = checkpointable_utils.add_variable( + second, "v2", initializer=[1., 1., 2., 3.]) + self.evaluate(checkpointable_utils.gather_initializers(first)) + checkpoint_directory = self.get_temp_dir() + save_path = checkpointable_utils.Saver(first).save( + os.path.join(checkpoint_directory, "ckpt")) + + # Test deferred loading + first_load = checkpointable.Checkpointable() + status = checkpointable_utils.Saver(first_load).restore(save_path) + second_load = checkpointable.Checkpointable() + first_load.second = second_load + second_load.first = first_load + with self.assertRaises(AssertionError): + status.assert_consumed() + first_load.v = checkpointable_utils.add_variable( + first_load, "v1", shape=[3]) + second_load.v = checkpointable_utils.add_variable( + second_load, "v2", shape=[4]) + status.assert_consumed() + status.run_restore_ops() + self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) + self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) + + # Test loading when variables have already been created + self.evaluate(first_load.v.assign([2., 7., 1.])) + self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v)) + self.evaluate(second_load.v.assign([2., 7., 1., 8.])) + self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v)) + status = checkpointable_utils.Saver(first_load).restore( + save_path).assert_consumed() + status.run_restore_ops() + self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v)) + self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v)) + + @test_util.run_in_graph_and_eager_modes() + def testRestoreOnAssign(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session(save_graph): + first = checkpointable.Checkpointable() + first.var1 = variable_scope.get_variable( + name="outside_var", initializer=0.) + first.var2 = variable_scope.get_variable( + name="blah", initializer=0.) + self.evaluate(first.var1.assign(4.)) + self.evaluate(first.var2.assign(8.)) + save_path = checkpointable_utils.Saver(first).save( + checkpoint_prefix) + restore_graph = ops.Graph() + with restore_graph.as_default(), self.test_session(restore_graph): + second = checkpointable.Checkpointable() + second.var2 = variable_scope.get_variable( + name="blah", initializer=0.) + status = checkpointable_utils.Saver(second).restore(save_path) + recreated_var1 = variable_scope.get_variable( + name="outside_var", initializer=0.) + status.run_restore_ops() + self.assertEqual(8., self.evaluate(second.var2)) + self.evaluate(recreated_var1.assign(-2.)) + self.assertEqual(-2., self.evaluate(recreated_var1)) + second.var1 = recreated_var1 + status.run_restore_ops() + self.assertEqual(4., self.evaluate(recreated_var1)) + + def testManySavesGraph(self): + """Saves after the first should not modify the graph.""" + with context.graph_mode(): + graph = ops.Graph() + with graph.as_default(), self.test_session(graph): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + obj = checkpointable.Checkpointable() + obj.var = variable_scope.get_variable(name="v", initializer=0.) + obj.opt = CheckpointableAdam(0.1) + obj.opt.minimize(obj.var.read_value()) + self.evaluate(checkpointable_utils.gather_initializers(obj)) + saver = checkpointable_utils.Saver(obj) + saver.save(checkpoint_prefix) + before_ops = graph.get_operations() + saver.save(checkpoint_prefix) + self.assertEqual(before_ops, graph.get_operations()) + + def testManyRestoresGraph(self): + """Restores after the first should not modify the graph.""" + with context.graph_mode(): + graph = ops.Graph() + with graph.as_default(), self.test_session(graph): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + obj = checkpointable.Checkpointable() + obj.var = variable_scope.get_variable(name="v", initializer=0.) + obj.opt = CheckpointableAdam(0.1) + obj.opt.minimize(obj.var.read_value()) + self.evaluate(checkpointable_utils.gather_initializers(obj)) + saver = checkpointable_utils.Saver(obj) + save_path = saver.save(checkpoint_prefix) + saver.restore(save_path) + before_ops = graph.get_operations() + saver.restore(save_path) + self.assertEqual(before_ops, graph.get_operations()) + + +class CheckpointCompatibilityTests(test.TestCase): + + def _initialized_model(self): + input_value = constant_op.constant([[3.]]) + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + optimizer_step = training_util.get_or_create_global_step() + root_checkpointable = Checkpoint( + optimizer=optimizer, network=network, optimizer_step=optimizer_step) + train_op = optimizer.minimize( + functools.partial(network, input_value), + global_step=optimizer_step) + self.evaluate(checkpointable_utils.gather_initializers( + root_checkpointable)) + self.evaluate(train_op) + # A regular variable, a slot variable, and a non-slot Optimizer variable + # with known values to check when loading. + self.evaluate(network._named_dense.bias.assign([1.])) + self.evaluate(optimizer.get_slot( + var=network._named_dense.bias, name="m").assign([2.])) + beta1_power, _ = optimizer._get_beta_accumulators() + self.evaluate(beta1_power.assign(3.)) + return root_checkpointable + + def _set_sentinels(self, root_checkpointable): + self.evaluate(root_checkpointable.network._named_dense.bias.assign([101.])) + self.evaluate( + root_checkpointable.optimizer.get_slot( + var=root_checkpointable.network._named_dense.bias, name="m") + .assign([102.])) + beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + self.evaluate(beta1_power.assign(103.)) + + def _check_sentinels(self, root_checkpointable): + self.assertAllEqual( + [1.], self.evaluate(root_checkpointable.network._named_dense.bias)) + self.assertAllEqual([2.], self.evaluate( + root_checkpointable.optimizer.get_slot( + var=root_checkpointable.network._named_dense.bias, name="m"))) + beta1_power, _ = root_checkpointable.optimizer._get_beta_accumulators() + self.assertAllEqual(3., self.evaluate(beta1_power)) + + def _write_name_based_checkpoint(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + with context.graph_mode(): + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session( + graph=save_graph) as session: + root = self._initialized_model() + name_saver = core_saver.Saver() + return name_saver.save( + sess=session, save_path=checkpoint_prefix, + global_step=root.optimizer_step) + + @test_util.run_in_graph_and_eager_modes() + def testLoadFromNameBasedSaver(self): + """Save a name-based checkpoint, load it using the object-based API.""" + save_path = self._write_name_based_checkpoint() + root = self._initialized_model() + self._set_sentinels(root) + with self.assertRaises(AssertionError): + self._check_sentinels(root) + object_saver = checkpointable_utils.Saver(root) + status = object_saver.restore(save_path) + with self.assertRaises(AssertionError): + status.assert_consumed() + status.run_restore_ops() + self._check_sentinels(root) + self._set_sentinels(root) + status.initialize_or_restore() + self._check_sentinels(root) + + # TODO(allenl): Test for the core name-based saver loading object-based + # checkpoints once object-based checkpointing is in core. + + def testSaveGraphLoadEager(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + with context.graph_mode(): + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session( + graph=save_graph) as session: + root = self._initialized_model() + object_saver = checkpointable_utils.Saver(root) + save_path = object_saver.save( + session=session, file_prefix=checkpoint_prefix) + with context.eager_mode(): + root = self._initialized_model() + self._set_sentinels(root) + root.restore(save_path).assert_consumed() + self._check_sentinels(root) + + def testSaveEagerLoadGraph(self): + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + with context.eager_mode(): + root = self._initialized_model() + object_saver = checkpointable_utils.Saver(root) + save_path = object_saver.save(file_prefix=checkpoint_prefix) + with context.graph_mode(): + save_graph = ops.Graph() + with save_graph.as_default(), self.test_session( + graph=save_graph): + root = self._initialized_model() + self._set_sentinels(root) + root.restore(save_path).assert_consumed().run_restore_ops() + self._check_sentinels(root) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index a1611e92b113839c2dd2a3b2560b0ba90c0a7ef0..35c3c5d3fad0a84bbe4d24c7bb17878583bded4b 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -16,11 +16,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import threading import time import numpy as np from tensorflow.contrib import lookup +from tensorflow.contrib.data.python.ops import threadpool +from tensorflow.contrib.data.python.ops import unique from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset from tensorflow.python.eager import test @@ -165,6 +168,38 @@ class IteratorTest(test.TestCase): x = math_ops.add(x, x) self.assertAllEqual([0., 2.], x.numpy()) + def testOverrideThreadPool(self): + + def get_thread_id(_): + # Python creates a dummy thread object to represent the current + # thread when called from an "alien" thread (such as a + # `PrivateThreadPool` thread in this case). It does not include + # the TensorFlow-given display name, but it has a unique + # identifier that maps one-to-one with the underlying OS thread. + return np.array(threading.current_thread().ident).astype(np.int64) + + for num_threads in [1, 2, 4, 8, 16]: + + dataset = ( + Dataset.range(1000).map( + lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64), + num_parallel_calls=32).apply(unique.unique())) + + dataset = threadpool.override_threadpool( + dataset, + threadpool.PrivateThreadPool( + num_threads, display_name='private_thread_pool_%d' % num_threads)) + + thread_ids = [] + for next_element in datasets.Iterator(dataset): + thread_ids.append(next_element) + self.assertEqual(len(thread_ids), len(set(thread_ids))) + self.assertGreater(len(thread_ids), 0) + # NOTE(mrry): We don't control the thread pool scheduling, and + # so cannot guarantee that all of the threads in the pool will + # perform work. + self.assertLessEqual(len(thread_ids), num_threads) + class DatasetConstructorBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index 15a21885f66eface291a39fa0ee1ff28bc297548..c1fd9e0ed020beeb722204edf1adfe1dfcf8ff03 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -8,7 +8,6 @@ py_library( deps = [ "//tensorflow/contrib/eager/python/examples/gan:mnist", "//tensorflow/contrib/eager/python/examples/linear_regression", - "//tensorflow/contrib/eager/python/examples/mnist", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py index b9ac79f46c83bb709918e3b72830b90ddcfd71b4..5f51d52622caedc6baa9f9f9950a6fd91761259a 100644 --- a/tensorflow/contrib/eager/python/examples/gan/mnist.py +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -35,7 +35,7 @@ from tensorflow.examples.tutorials.mnist import input_data FLAGS = None -class Discriminator(tfe.Network): +class Discriminator(tf.keras.Model): """GAN Discriminator. A network to differentiate between generated and real handwritten digits. @@ -56,19 +56,15 @@ class Discriminator(tfe.Network): else: assert data_format == 'channels_last' self._input_shape = [-1, 28, 28, 1] - self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME', - data_format=data_format, - activation=tf.tanh)) - self.pool1 = self.track_layer( - tf.layers.AveragePooling2D(2, 2, data_format=data_format)) - self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5, - data_format=data_format, - activation=tf.tanh)) - self.pool2 = self.track_layer( - tf.layers.AveragePooling2D(2, 2, data_format=data_format)) - self.flatten = self.track_layer(tf.layers.Flatten()) - self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh)) - self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None)) + self.conv1 = tf.layers.Conv2D( + 64, 5, padding='SAME', data_format=data_format, activation=tf.tanh) + self.pool1 = tf.layers.AveragePooling2D(2, 2, data_format=data_format) + self.conv2 = tf.layers.Conv2D( + 128, 5, data_format=data_format, activation=tf.tanh) + self.pool2 = tf.layers.AveragePooling2D(2, 2, data_format=data_format) + self.flatten = tf.layers.Flatten() + self.fc1 = tf.layers.Dense(1024, activation=tf.tanh) + self.fc2 = tf.layers.Dense(1, activation=None) def call(self, inputs): """Return two logits per image estimating input authenticity. @@ -95,7 +91,7 @@ class Discriminator(tfe.Network): return x -class Generator(tfe.Network): +class Generator(tf.keras.Model): """Generator of handwritten digits similar to the ones in the MNIST dataset. """ @@ -116,18 +112,17 @@ class Generator(tfe.Network): else: assert data_format == 'channels_last' self._pre_conv_shape = [-1, 6, 6, 128] - self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128, - activation=tf.tanh)) + self.fc1 = tf.layers.Dense(6 * 6 * 128, activation=tf.tanh) # In call(), we reshape the output of fc1 to _pre_conv_shape # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64) - self.conv1 = self.track_layer(tf.layers.Conv2DTranspose( - 64, 4, strides=2, activation=None, data_format=data_format)) + self.conv1 = tf.layers.Conv2DTranspose( + 64, 4, strides=2, activation=None, data_format=data_format) # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1) - self.conv2 = self.track_layer(tf.layers.Conv2DTranspose( - 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)) + self.conv2 = tf.layers.Conv2DTranspose( + 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format) def call(self, inputs): """Return a batch of generated images. @@ -168,7 +163,8 @@ def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs): """ loss_on_real = tf.losses.sigmoid_cross_entropy( - tf.ones_like(discriminator_real_outputs), discriminator_real_outputs, + tf.ones_like(discriminator_real_outputs), + discriminator_real_outputs, label_smoothing=0.25) loss_on_generated = tf.losses.sigmoid_cross_entropy( tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs) @@ -198,9 +194,8 @@ def generator_loss(discriminator_gen_outputs): return loss -def train_one_epoch(generator, discriminator, - generator_optimizer, discriminator_optimizer, - dataset, log_interval, noise_dim): +def train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, dataset, log_interval, noise_dim): """Trains `generator` and `discriminator` models on `dataset`. Args: @@ -222,14 +217,18 @@ def train_one_epoch(generator, discriminator, with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval): current_batch_size = images.shape[0] - noise = tf.random_uniform(shape=[current_batch_size, noise_dim], - minval=-1., maxval=1., seed=batch_index) + noise = tf.random_uniform( + shape=[current_batch_size, noise_dim], + minval=-1., + maxval=1., + seed=batch_index) with tfe.GradientTape(persistent=True) as g: generated_images = generator(noise) - tf.contrib.summary.image('generated_images', - tf.reshape(generated_images, [-1, 28, 28, 1]), - max_images=10) + tf.contrib.summary.image( + 'generated_images', + tf.reshape(generated_images, [-1, 28, 28, 1]), + max_images=10) discriminator_gen_outputs = discriminator(generated_images) discriminator_real_outputs = discriminator(images) @@ -245,17 +244,17 @@ def train_one_epoch(generator, discriminator, discriminator.variables) with tf.variable_scope('generator'): - generator_optimizer.apply_gradients(zip(generator_grad, - generator.variables)) + generator_optimizer.apply_gradients( + zip(generator_grad, generator.variables)) with tf.variable_scope('discriminator'): - discriminator_optimizer.apply_gradients(zip(discriminator_grad, - discriminator.variables)) + discriminator_optimizer.apply_gradients( + zip(discriminator_grad, discriminator.variables)) if log_interval and batch_index > 0 and batch_index % log_interval == 0: print('Batch #%d\tAverage Generator Loss: %.6f\t' - 'Average Discriminator Loss: %.6f' % ( - batch_index, total_generator_loss/batch_index, - total_discriminator_loss/batch_index)) + 'Average Discriminator Loss: %.6f' % + (batch_index, total_generator_loss / batch_index, + total_discriminator_loss / batch_index)) def main(_): @@ -266,10 +265,9 @@ def main(_): # Load the datasets data = input_data.read_data_sets(FLAGS.data_dir) - dataset = (tf.data.Dataset - .from_tensor_slices(data.train.images) - .shuffle(60000) - .batch(FLAGS.batch_size)) + dataset = ( + tf.data.Dataset.from_tensor_slices(data.train.images).shuffle(60000) + .batch(FLAGS.batch_size)) # Create the models and optimizers generator = Generator(data_format) @@ -294,20 +292,17 @@ def main(_): start = time.time() with summary_writer.as_default(): train_one_epoch(generator, discriminator, generator_optimizer, - discriminator_optimizer, - dataset, FLAGS.log_interval, FLAGS.noise) + discriminator_optimizer, dataset, FLAGS.log_interval, + FLAGS.noise) end = time.time() - print('\nTrain time for epoch #%d (global step %d): %f' % ( - epoch, global_step.numpy(), end - start)) + print('\nTrain time for epoch #%d (global step %d): %f' % + (epoch, global_step.numpy(), end - start)) all_variables = ( - generator.variables - + discriminator.variables - + generator_optimizer.variables() - + discriminator_optimizer.variables() - + [global_step]) - tfe.Saver(all_variables).save( - checkpoint_prefix, global_step=global_step) + generator.variables + discriminator.variables + + generator_optimizer.variables() + + discriminator_optimizer.variables() + [global_step]) + tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index 6ce4de6ee0bf50400eff339ac04e132252a2b53e..157a6360ea555bba37df008a6458acac0342880b 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -33,23 +33,13 @@ import tensorflow as tf import tensorflow.contrib.eager as tfe -class LinearModel(tfe.Network): - """A TensorFlow linear regression model. - - Uses TensorFlow's eager execution. - - For those familiar with TensorFlow graphs, notice the absence of - `tf.Session`. The `forward()` method here immediately executes and - returns output values. The `loss()` method immediately compares the - output of `forward()` with the target and returns the MSE loss value. - The `fit()` performs gradient-descent training on the model's weights - and bias. - """ +class LinearModel(tf.keras.Model): + """A TensorFlow linear regression model.""" def __init__(self): """Constructs a LinearModel object.""" super(LinearModel, self).__init__() - self._hidden_layer = self.track_layer(tf.layers.Dense(1)) + self._hidden_layer = tf.layers.Dense(1) def call(self, xs): """Invoke the linear model. diff --git a/tensorflow/contrib/eager/python/examples/mnist/BUILD b/tensorflow/contrib/eager/python/examples/mnist/BUILD deleted file mode 100644 index c61ec2dbae60a782c0e6589701554b045dcb92ae..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/BUILD +++ /dev/null @@ -1,36 +0,0 @@ -licenses(["notice"]) # Apache 2.0 - -package(default_visibility = ["//tensorflow:internal"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -py_binary( - name = "mnist", - srcs = ["mnist.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow:tensorflow_py", - "//tensorflow/contrib/eager/python:tfe", - "//tensorflow/examples/tutorials/mnist:input_data", - ], -) - -cuda_py_test( - name = "mnist_test", - srcs = ["mnist_test.py"], - additional_deps = [ - ":mnist", - "//tensorflow/contrib/eager/python:tfe", - "//tensorflow:tensorflow_py", - ], -) - -cuda_py_test( - name = "mnist_graph_test", - srcs = ["mnist_graph_test.py"], - additional_deps = [ - ":mnist", - "//third_party/py/numpy", - "//tensorflow:tensorflow_py", - ], -) diff --git a/tensorflow/contrib/eager/python/examples/mnist/README.md b/tensorflow/contrib/eager/python/examples/mnist/README.md index e987996b88ccf54a322749aadec4f9840760a90f..d1c079ff6b5cb187bbcfe2742293982b1bedd2d4 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/README.md +++ b/tensorflow/contrib/eager/python/examples/mnist/README.md @@ -1,10 +1 @@ -Classification model for the MNIST dataset using eager execution. - -To run: - -``` -python mnist.py -``` - -`mnist_graph_test.py` demonstrates that the same code that is executed eagerly -in `mnist.py` is used to construct a TensorFlow graph. +See https://github.com/tensorflow/models/tree/master/official/mnist/mnist_eager.py diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py deleted file mode 100644 index 772f59562ba27cce510c82681f491d005298f44c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ /dev/null @@ -1,265 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""A deep MNIST classifier using convolutional layers. - -Sample usage: - python mnist.py --help -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import os -import sys -import time - -import tensorflow as tf - -import tensorflow.contrib.eager as tfe -from tensorflow.examples.tutorials.mnist import input_data - -FLAGS = None - - -class MNISTModel(tfe.Network): - """MNIST Network. - - Network structure is equivalent to: - https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/mnist/mnist_deep.py - and - https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py - - But written using the tf.layers API. - """ - - def __init__(self, data_format): - """Creates a model for classifying a hand-written digit. - - Args: - data_format: Either 'channels_first' or 'channels_last'. - 'channels_first' is typically faster on GPUs while 'channels_last' is - typically faster on CPUs. See - https://www.tensorflow.org/performance/performance_guide#data_formats - """ - super(MNISTModel, self).__init__(name='') - if data_format == 'channels_first': - self._input_shape = [-1, 1, 28, 28] - else: - assert data_format == 'channels_last' - self._input_shape = [-1, 28, 28, 1] - self.conv1 = self.track_layer( - tf.layers.Conv2D(32, 5, data_format=data_format, activation=tf.nn.relu)) - self.conv2 = self.track_layer( - tf.layers.Conv2D(64, 5, data_format=data_format, activation=tf.nn.relu)) - self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.nn.relu)) - self.fc2 = self.track_layer(tf.layers.Dense(10)) - self.dropout = self.track_layer(tf.layers.Dropout(0.5)) - self.max_pool2d = self.track_layer( - tf.layers.MaxPooling2D( - (2, 2), (2, 2), padding='SAME', data_format=data_format)) - - def call(self, inputs, training): - """Computes labels from inputs. - - Users should invoke __call__ to run the network, which delegates to this - method (and not call this method directly). - - Args: - inputs: A batch of images as a Tensor with shape [batch_size, 784]. - training: True if invoked in the context of training (causing dropout to - be applied). False otherwise. - - Returns: - A Tensor with shape [batch_size, 10] containing the predicted logits - for each image in the batch, for each of the 10 classes. - """ - - x = tf.reshape(inputs, self._input_shape) - x = self.conv1(x) - x = self.max_pool2d(x) - x = self.conv2(x) - x = self.max_pool2d(x) - x = tf.layers.flatten(x) - x = self.fc1(x) - x = self.dropout(x, training=training) - x = self.fc2(x) - return x - - -def loss(predictions, labels): - return tf.reduce_mean( - tf.nn.softmax_cross_entropy_with_logits( - logits=predictions, labels=labels)) - - -def compute_accuracy(predictions, labels): - return tf.reduce_sum( - tf.cast( - tf.equal( - tf.argmax(predictions, axis=1, - output_type=tf.int64), - tf.argmax(labels, axis=1, - output_type=tf.int64)), - dtype=tf.float32)) / float(predictions.shape[0].value) - - -def train_one_epoch(model, optimizer, dataset, log_interval=None): - """Trains model on `dataset` using `optimizer`.""" - - tf.train.get_or_create_global_step() - - for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)): - with tf.contrib.summary.record_summaries_every_n_global_steps(10): - with tfe.GradientTape() as tape: - prediction = model(images, training=True) - loss_value = loss(prediction, labels) - tf.contrib.summary.scalar('loss', loss_value) - tf.contrib.summary.scalar('accuracy', - compute_accuracy(prediction, labels)) - grads = tape.gradient(loss_value, model.variables) - optimizer.apply_gradients(zip(grads, model.variables)) - if log_interval and batch % log_interval == 0: - print('Batch #%d\tLoss: %.6f' % (batch, loss_value)) - - -def test(model, dataset): - """Perform an evaluation of `model` on the examples from `dataset`.""" - avg_loss = tfe.metrics.Mean('loss') - accuracy = tfe.metrics.Accuracy('accuracy') - - for (images, labels) in tfe.Iterator(dataset): - predictions = model(images, training=False) - avg_loss(loss(predictions, labels)) - accuracy(tf.argmax(predictions, axis=1, output_type=tf.int64), - tf.argmax(labels, axis=1, output_type=tf.int64)) - print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' % - (avg_loss.result(), 100 * accuracy.result())) - with tf.contrib.summary.always_record_summaries(): - tf.contrib.summary.scalar('loss', avg_loss.result()) - tf.contrib.summary.scalar('accuracy', accuracy.result()) - - -def load_data(data_dir): - """Returns training and test tf.data.Dataset objects.""" - data = input_data.read_data_sets(data_dir, one_hot=True) - train_ds = tf.data.Dataset.from_tensor_slices((data.train.images, - data.train.labels)) - test_ds = tf.data.Dataset.from_tensors((data.test.images, data.test.labels)) - return (train_ds, test_ds) - - -def main(_): - tfe.enable_eager_execution() - - (device, data_format) = ('/gpu:0', 'channels_first') - if FLAGS.no_gpu or tfe.num_gpus() <= 0: - (device, data_format) = ('/cpu:0', 'channels_last') - print('Using device %s, and data format %s.' % (device, data_format)) - - # Load the datasets - (train_ds, test_ds) = load_data(FLAGS.data_dir) - train_ds = train_ds.shuffle(60000).batch(FLAGS.batch_size) - - # Create the model and optimizer - model = MNISTModel(data_format) - optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum) - - if FLAGS.output_dir: - train_dir = os.path.join(FLAGS.output_dir, 'train') - test_dir = os.path.join(FLAGS.output_dir, 'eval') - tf.gfile.MakeDirs(FLAGS.output_dir) - else: - train_dir = None - test_dir = None - summary_writer = tf.contrib.summary.create_file_writer( - train_dir, flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_file_writer( - test_dir, flush_millis=10000, name='test') - checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') - - with tf.device(device): - for epoch in range(1, 11): - with tfe.restore_variables_on_create( - tf.train.latest_checkpoint(FLAGS.checkpoint_dir)): - global_step = tf.train.get_or_create_global_step() - start = time.time() - with summary_writer.as_default(): - train_one_epoch(model, optimizer, train_ds, FLAGS.log_interval) - end = time.time() - print('\nTrain time for epoch #%d (global step %d): %f' % ( - epoch, global_step.numpy(), end - start)) - with test_summary_writer.as_default(): - test(model, test_ds) - all_variables = ( - model.variables - + optimizer.variables() - + [global_step]) - tfe.Saver(all_variables).save( - checkpoint_prefix, global_step=global_step) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument( - '--data-dir', - type=str, - default='/tmp/tensorflow/mnist/input_data', - help='Directory for storing input data') - parser.add_argument( - '--batch-size', - type=int, - default=64, - metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument( - '--log-interval', - type=int, - default=10, - metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument( - '--output_dir', - type=str, - default=None, - metavar='N', - help='Directory to write TensorBoard summaries') - parser.add_argument( - '--checkpoint_dir', - type=str, - default='/tmp/tensorflow/mnist/checkpoints/', - metavar='N', - help='Directory to save checkpoints in (once per epoch)') - parser.add_argument( - '--lr', - type=float, - default=0.01, - metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument( - '--momentum', - type=float, - default=0.5, - metavar='M', - help='SGD momentum (default: 0.5)') - parser.add_argument( - '--no-gpu', - action='store_true', - default=False, - help='disables GPU usage even if a GPU is available') - - FLAGS, unparsed = parser.parse_known_args() - tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py deleted file mode 100644 index 1af26553120b34d4682b17b1c29c81dc65e421d4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -import tensorflow as tf -from tensorflow.contrib.eager.python.examples.mnist import mnist - - -def data_format(): - return "channels_first" if tf.test.is_gpu_available() else "channels_last" - - -class MNISTGraphTest(tf.test.TestCase): - - def testTrainGraph(self): - # The MNISTModel class can be executed eagerly (as in mnist.py and - # mnist_test.py) and also be used to construct a TensorFlow graph, which is - # then trained in a session. - with tf.Graph().as_default(): - # Generate some random data. - batch_size = 64 - images = np.random.randn(batch_size, 784).astype(np.float32) - digits = np.random.randint(low=0, high=10, size=batch_size) - labels = np.zeros((batch_size, 10)) - labels[np.arange(batch_size), digits] = 1. - - # Create a model, optimizer, and dataset as would be done - # for eager execution as well. - model = mnist.MNISTModel(data_format()) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) - dataset = tf.data.Dataset.from_tensors((images, labels)) - - # Define the loss tensor (as opposed to a loss function when - # using eager execution). - (images, labels) = dataset.make_one_shot_iterator().get_next() - predictions = model(images, training=True) - loss = mnist.loss(predictions, labels) - - train_op = optimizer.minimize(loss) - init = tf.global_variables_initializer() - with tf.Session() as sess: - # Variables have to be initialized in the session. - sess.run(init) - # Train using the optimizer. - sess.run(train_op) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py deleted file mode 100644 index 136085eba21284a42282395e54f32c33bf63b5c3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - -import tensorflow.contrib.eager as tfe -from tensorflow.contrib.eager.python.examples.mnist import mnist - - -def device(): - return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0" - - -def data_format(): - return "channels_first" if tfe.num_gpus() else "channels_last" - - -def random_dataset(): - batch_size = 64 - images = tf.random_normal([batch_size, 784]) - digits = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32) - labels = tf.one_hot(digits, 10) - return tf.data.Dataset.from_tensors((images, labels)) - - -def train_one_epoch(defun=False): - model = mnist.MNISTModel(data_format()) - if defun: - model.call = tfe.defun(model.call) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - dataset = random_dataset() - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.train_one_epoch(model, optimizer, dataset) - - -def evaluate(defun=False): - model = mnist.MNISTModel(data_format()) - dataset = random_dataset() - if defun: - model.call = tfe.defun(model.call) - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.test(model, dataset) - - -class MNISTTest(tf.test.TestCase): - - def testTrainOneEpoch(self): - train_one_epoch(defun=False) - - def testTest(self): - evaluate(defun=False) - - def testTrainOneEpochWithDefunCall(self): - train_one_epoch(defun=True) - - def testTestWithDefunCall(self): - evaluate(defun=True) - - -if __name__ == "__main__": - tfe.enable_eager_execution() - tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py index 9982fdb07eefa665379e7be095f4f8017d92cf97..6b59413141f78fc85474850e109454ecdeb68cd3 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -27,10 +27,9 @@ from __future__ import print_function import functools import tensorflow as tf -import tensorflow.contrib.eager as tfe -class _IdentityBlock(tfe.Network): +class _IdentityBlock(tf.keras.Model): """_IdentityBlock is the block that has no conv layer at shortcut. Args: @@ -50,31 +49,24 @@ class _IdentityBlock(tfe.Network): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = self.track_layer( - tf.layers.Conv2D( - filters1, (1, 1), - name=conv_name_base + '2a', - data_format=data_format)) - self.bn2a = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')) - - self.conv2b = self.track_layer( - tf.layers.Conv2D( - filters2, - kernel_size, - padding='same', - data_format=data_format, - name=conv_name_base + '2b')) - self.bn2b = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')) - - self.conv2c = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - name=conv_name_base + '2c', - data_format=data_format)) - self.bn2c = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')) + self.conv2a = tf.layers.Conv2D( + filters1, (1, 1), name=conv_name_base + '2a', data_format=data_format) + self.bn2a = tf.layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2a') + + self.conv2b = tf.layers.Conv2D( + filters2, + kernel_size, + padding='same', + data_format=data_format, + name=conv_name_base + '2b') + self.bn2b = tf.layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2b') + + self.conv2c = tf.layers.Conv2D( + filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) + self.bn2c = tf.layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2c') def call(self, input_tensor, training=False): x = self.conv2a(input_tensor) @@ -92,7 +84,7 @@ class _IdentityBlock(tfe.Network): return tf.nn.relu(x) -class _ConvBlock(tfe.Network): +class _ConvBlock(tf.keras.Model): """_ConvBlock is the block that has a conv layer at shortcut. Args: @@ -121,41 +113,35 @@ class _ConvBlock(tfe.Network): bn_name_base = 'bn' + str(stage) + block + '_branch' bn_axis = 1 if data_format == 'channels_first' else 3 - self.conv2a = self.track_layer( - tf.layers.Conv2D( - filters1, (1, 1), - strides=strides, - name=conv_name_base + '2a', - data_format=data_format)) - self.bn2a = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')) - - self.conv2b = self.track_layer( - tf.layers.Conv2D( - filters2, - kernel_size, - padding='same', - name=conv_name_base + '2b', - data_format=data_format)) - self.bn2b = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')) - - self.conv2c = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - name=conv_name_base + '2c', - data_format=data_format)) - self.bn2c = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')) - - self.conv_shortcut = self.track_layer( - tf.layers.Conv2D( - filters3, (1, 1), - strides=strides, - name=conv_name_base + '1', - data_format=data_format)) - self.bn_shortcut = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1')) + self.conv2a = tf.layers.Conv2D( + filters1, (1, 1), + strides=strides, + name=conv_name_base + '2a', + data_format=data_format) + self.bn2a = tf.layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2a') + + self.conv2b = tf.layers.Conv2D( + filters2, + kernel_size, + padding='same', + name=conv_name_base + '2b', + data_format=data_format) + self.bn2b = tf.layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2b') + + self.conv2c = tf.layers.Conv2D( + filters3, (1, 1), name=conv_name_base + '2c', data_format=data_format) + self.bn2c = tf.layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '2c') + + self.conv_shortcut = tf.layers.Conv2D( + filters3, (1, 1), + strides=strides, + name=conv_name_base + '1', + data_format=data_format) + self.bn_shortcut = tf.layers.BatchNormalization( + axis=bn_axis, name=bn_name_base + '1') def call(self, input_tensor, training=False): x = self.conv2a(input_tensor) @@ -176,7 +162,8 @@ class _ConvBlock(tfe.Network): return tf.nn.relu(x) -class ResNet50(tfe.Network): +# pylint: disable=not-callable +class ResNet50(tf.keras.Model): """Instantiates the ResNet50 architecture. Args: @@ -220,32 +207,28 @@ class ResNet50(tfe.Network): self.include_top = include_top def conv_block(filters, stage, block, strides=(2, 2)): - l = _ConvBlock( + return _ConvBlock( 3, filters, stage=stage, block=block, data_format=data_format, strides=strides) - return self.track_layer(l) def id_block(filters, stage, block): - l = _IdentityBlock( + return _IdentityBlock( 3, filters, stage=stage, block=block, data_format=data_format) - return self.track_layer(l) - - self.conv1 = self.track_layer( - tf.layers.Conv2D( - 64, (7, 7), - strides=(2, 2), - data_format=data_format, - padding='same', - name='conv1')) + + self.conv1 = tf.layers.Conv2D( + 64, (7, 7), + strides=(2, 2), + data_format=data_format, + padding='same', + name='conv1') bn_axis = 1 if data_format == 'channels_first' else 3 - self.bn_conv1 = self.track_layer( - tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1')) - self.max_pool = self.track_layer( - tf.layers.MaxPooling2D((3, 3), strides=(2, 2), data_format=data_format)) + self.bn_conv1 = tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1') + self.max_pool = tf.layers.MaxPooling2D( + (3, 3), strides=(2, 2), data_format=data_format) self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1)) self.l2b = id_block([64, 64, 256], stage=2, block='b') @@ -267,13 +250,11 @@ class ResNet50(tfe.Network): self.l5b = id_block([512, 512, 2048], stage=5, block='b') self.l5c = id_block([512, 512, 2048], stage=5, block='c') - self.avg_pool = self.track_layer( - tf.layers.AveragePooling2D( - (7, 7), strides=(7, 7), data_format=data_format)) + self.avg_pool = tf.layers.AveragePooling2D( + (7, 7), strides=(7, 7), data_format=data_format) if self.include_top: - self.fc1000 = self.track_layer( - tf.layers.Dense(classes, name='fc1000')) + self.fc1000 = tf.layers.Dense(classes, name='fc1000') else: reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3] reduction_indices = tf.constant(reduction_indices) @@ -288,7 +269,7 @@ class ResNet50(tfe.Network): else: self.global_pooling = None - def call(self, input_tensor, training=False): + def call(self, input_tensor, training): x = self.conv1(input_tensor) x = self.bn_conv1(x, training=training) x = tf.nn.relu(x) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 23317886e712323f4b520000e0fd372734fc53a1..551c76b0df71c88919df9cd6d81b4176b23b0ba3 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -55,7 +55,7 @@ class ResNet50GraphTest(tf.test.TestCase): with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) - predictions = model(images) + predictions = model(images, training=False) init = tf.global_variables_initializer() @@ -114,7 +114,7 @@ class ResNet50Benchmarks(tf.test.Benchmark): with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) - predictions = model(images) + predictions = model(images, training=False) init = tf.global_variables_initializer() diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 0ff8746884c288f824f5f22ab4c550370d0e0302..65dcc53aab39670cae10846b6996c17d7b4c5ba8 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -71,7 +71,7 @@ class ResNet50Test(tf.test.TestCase): model.call = tfe.defun(model.call) with tf.device(device): images, _ = random_batch(2) - output = model(images) + output = model(images, training=False) self.assertEqual((2, 1000), output.shape) def test_apply(self): @@ -85,7 +85,7 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format, include_top=False) with tf.device(device): images, _ = random_batch(2) - output = model(images) + output = model(images, training=False) output_shape = ((2, 2048, 1, 1) if data_format == 'channels_first' else (2, 1, 1, 2048)) self.assertEqual(output_shape, output.shape) @@ -95,7 +95,7 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') with tf.device(device): images, _ = random_batch(2) - output = model(images) + output = model(images, training=False) self.assertEqual((2, 2048), output.shape) def test_train(self): @@ -194,11 +194,11 @@ class ResNet50Benchmarks(tf.test.Benchmark): with tf.device(device): images, _ = random_batch(batch_size) for _ in xrange(num_burn): - model(images).cpu() + model(images, training=False).cpu() gc.collect() start = time.time() for _ in xrange(num_iters): - model(images).cpu() + model(images, training=False).cpu() self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_apply(self): diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD index 21055cfe1110b565d7ef658b6f9024b7cdb9669a..a1f8a759e2a556bc219f0aa13942f293c4f34cfa 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/BUILD +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -38,9 +38,5 @@ cuda_py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_test_lib", ], - tags = [ - "manual", - "no_gpu", - "no_pip", # because spinn.py is under third_party/. - ], + tags = ["no_pip"], # because spinn.py is under third_party/. ) diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py index fcaae0a4f8c0bad916d74bd9b80fcfa55a63d84a..3bc3bb49bcbbc26f7a3134a8bfc385ec080dde1e 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/data.py +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -227,6 +227,29 @@ def calculate_bins(length2count, min_bin_size): return bounds +def encode_sentence(sentence, word2index): + """Encode a single sentence as word indices and shift-reduce code. + + Args: + sentence: The sentence with added binary parse information, represented as + a string, with all the word items and parentheses separated by spaces. + E.g., '( ( The dog ) ( ( is ( playing toys ) ) . ) )'. + word2index: A `dict` mapping words to their word indices. + + Returns: + 1. Word indices as a numpy array, with shape `(sequence_len, 1)`. + 2. Shift-reduce sequence as a numpy array, with shape + `(sequence_len * 2 - 3, 1)`. + """ + items = [w for w in sentence.split(" ") if w] + words = get_non_parenthesis_words(items) + shift_reduce = get_shift_reduce(items) + word_indices = pad_and_reverse_word_ids( + [[word2index.get(word, UNK_CODE) for word in words]]).T + return (word_indices, + np.expand_dims(np.array(shift_reduce, dtype=np.int64), -1)) + + class SnliData(object): """A split of SNLI data.""" diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py index e4f0b37c5099e45b7e3b258b258c0a203c36b3b7..54fef2c3fe4111cd2d93ac109a5b8fffad0c2fad 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/data_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py @@ -22,6 +22,7 @@ import os import shutil import tempfile +import numpy as np import tensorflow as tf from tensorflow.contrib.eager.python.examples.spinn import data @@ -173,14 +174,9 @@ class DataTest(tf.test.TestCase): ValueError, "Cannot find GloVe embedding file at"): data.load_word_vectors(self._temp_data_dir, vocab) - def testSnliData(self): - """Unit test for SnliData objects.""" - snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") - fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") - os.makedirs(snli_1_0_dir) - + def _createFakeSnliData(self, fake_snli_file): # Four sentences in total. - with open(fake_train_file, "wt") as f: + with open(fake_snli_file, "wt") as f: f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") @@ -205,10 +201,7 @@ class DataTest(tf.test.TestCase): "4705552913.jpg#2\t4705552913.jpg#2r1n\t" "neutral\tentailment\tneutral\tneutral\tneutral\n") - glove_dir = os.path.join(self._temp_data_dir, "glove") - os.makedirs(glove_dir) - glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") - + def _createFakeGloveData(self, glove_file): words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] with open(glove_file, "wt") as f: for i, word in enumerate(words): @@ -220,6 +213,40 @@ class DataTest(tf.test.TestCase): else: f.write("\n") + def testEncodeSingleSentence(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + self._createFakeSnliData(fake_train_file) + vocab = data.load_vocabulary(self._temp_data_dir) + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + self._createFakeGloveData(glove_file) + word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) + + sentence_variants = [ + "( Foo ( ( bar baz ) . ) )", + " ( Foo ( ( bar baz ) . ) ) ", + "( Foo ( ( bar baz ) . ) )"] + for sentence in sentence_variants: + word_indices, shift_reduce = data.encode_sentence(sentence, word2index) + self.assertEqual(np.int64, word_indices.dtype) + self.assertEqual((5, 1), word_indices.shape) + self.assertAllClose( + np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T, shift_reduce) + + def testSnliData(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + self._createFakeSnliData(fake_train_file) + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + self._createFakeGloveData(glove_file) + vocab = data.load_vocabulary(self._temp_data_dir) word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) @@ -230,7 +257,7 @@ class DataTest(tf.test.TestCase): self.assertEqual(1, train_data.num_batches(4)) generator = train_data.get_generator(2)() - for i in range(2): + for _ in range(2): label, prem, prem_trans, hypo, hypo_trans = next(generator) self.assertEqual(2, len(label)) self.assertEqual((4, 2), prem.shape) diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py index 7b2f09cba16311ad36f3fa33b97111c4127fef33..081b0af14fcc983a3f85d2a50e2bb04d2f2493b3 100644 --- a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -36,6 +36,7 @@ from third_party.examples.eager.spinn import spinn from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import test from tensorflow.python.framework import test_util +from tensorflow.python.training import checkpoint_utils # pylint: enable=g-bad-import-order @@ -66,13 +67,30 @@ def _generate_synthetic_snli_data_batch(sequence_length, return labels, prem, prem_trans, hypo, hypo_trans -def _test_spinn_config(d_embed, d_out, logdir=None): +def _test_spinn_config(d_embed, d_out, logdir=None, inference_sentences=None): + """Generate a config tuple for testing. + + Args: + d_embed: Embedding dimensions. + d_out: Model output dimensions. + logdir: Optional logdir. + inference_sentences: A 2-tuple of strings representing the sentences (with + binary parsing result), e.g., + ("( ( The dog ) ( ( is running ) . ) )", "( ( The dog ) ( moves . ) )"). + + Returns: + A config tuple. + """ config_tuple = collections.namedtuple( "Config", ["d_hidden", "d_proj", "d_tracker", "predict", "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp", "d_out", "projection", "lr", "batch_size", "epochs", "force_cpu", "logdir", "log_every", "dev_every", "save_every", - "lr_decay_every", "lr_decay_by"]) + "lr_decay_every", "lr_decay_by", "inference_premise", + "inference_hypothesis"]) + + inference_premise = inference_sentences[0] if inference_sentences else None + inference_hypothesis = inference_sentences[1] if inference_sentences else None return config_tuple( d_hidden=d_embed, d_proj=d_embed * 2, @@ -86,14 +104,16 @@ def _test_spinn_config(d_embed, d_out, logdir=None): projection=True, lr=2e-2, batch_size=2, - epochs=10, + epochs=20, force_cpu=False, logdir=logdir, log_every=1, dev_every=2, save_every=2, lr_decay_every=1, - lr_decay_by=0.75) + lr_decay_by=0.75, + inference_premise=inference_premise, + inference_hypothesis=inference_hypothesis) class SpinnTest(test_util.TensorFlowTestCase): @@ -288,11 +308,7 @@ class SpinnTest(test_util.TensorFlowTestCase): # Training on the batch should have led to a change in the loss value. self.assertNotEqual(loss1.numpy(), loss2.numpy()) - def testTrainSpinn(self): - """Test with fake toy SNLI data and GloVe vectors.""" - - # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. - snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + def _create_test_data(self, snli_1_0_dir): fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") os.makedirs(snli_1_0_dir) @@ -337,13 +353,52 @@ class SpinnTest(test_util.TensorFlowTestCase): else: f.write("\n") + return fake_train_file + + def testInferSpinnWorks(self): + """Test inference with the spinn model.""" + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + self._create_test_data(snli_1_0_dir) + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir"), + inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )")) + logits = spinn.train_or_infer_spinn( + embed, word2index, None, None, None, config) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((3,), logits.shape) + + def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + self._create_test_data(snli_1_0_dir) + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir"), + inference_sentences=("( foo ( bar . ) )", None)) + with self.assertRaises(ValueError): + spinn.train_or_infer_spinn(embed, word2index, None, None, None, config) + + def testTrainSpinn(self): + """Test with fake toy SNLI data and GloVe vectors.""" + + # 1. Create and load a fake SNLI data file and a fake GloVe embedding file. + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = self._create_test_data(snli_1_0_dir) + vocab = data.load_vocabulary(self._temp_data_dir) word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) train_data = data.SnliData(fake_train_file, word2index) dev_data = data.SnliData(fake_train_file, word2index) test_data = data.SnliData(fake_train_file, word2index) - print(embed) # 2. Create a fake config. config = _test_spinn_config( @@ -351,7 +406,8 @@ class SpinnTest(test_util.TensorFlowTestCase): logdir=os.path.join(self._temp_data_dir, "logdir")) # 3. Test training of a SPINN model. - spinn.train_spinn(embed, train_data, dev_data, test_data, config) + trainer = spinn.train_or_infer_spinn( + embed, word2index, train_data, dev_data, test_data, config) # 4. Load train loss values from the summary files and verify that they # decrease with training. @@ -363,6 +419,15 @@ class SpinnTest(test_util.TensorFlowTestCase): self.assertEqual(config.epochs, len(train_losses)) self.assertLess(train_losses[-1], train_losses[0]) + # 5. Verify that checkpoints exist and contains all the expected variables. + self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) + ckpt_variable_names = [ + item[0] for item in checkpoint_utils.list_variables(config.logdir)] + self.assertIn("global_step", ckpt_variable_names) + for v in trainer.variables: + variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name + self.assertIn(variable_name, ckpt_variable_names) + class EagerSpinnSNLIClassifierBenchmark(test.Benchmark): diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index ffc1d0332eae605ce0444a225e53baa68954cae0..ebb05051f27841f1cd3d21b6218986e774ed4c9f 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -22,11 +22,10 @@ to models defined without using eager execution. Eager execution is included in TensorFlow versions 1.5 and above. Installation instructions at https://www.tensorflow.org/install/ -The contents of this guide are compatible with TensorFlow 1.5. -However, if you run into bugs that are fixed in source but not the -release, you may want to either either [building from -source](https://www.tensorflow.org/install/install_sources) -or the try latest nightly builds. The nightly builds are available as: +The contents of this guide are compatible with TensorFlow 1.5. However, if you +run into bugs that are fixed in source but not the release, you may want to +either [build from source](https://www.tensorflow.org/install/install_sources) +or try a nightly build. The nightly builds are available as: - [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and @@ -570,8 +569,8 @@ for i in range(20001): print("Loss on test set: %f" % loss(model, data.test.images, data.test.labels).numpy()) ``` -For a more complete example, see -[`tensorflow/contrib/eager/python/examples/mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py) +For a more complete example, see [the example in the tensorflow/models +repository](https://github.com/tensorflow/models/tree/master/official/mnist/mnist_eager.py). ### Checkpointing trained variables @@ -860,11 +859,9 @@ eagerly or constructing graphs. This means that you can iteratively develop your model with eager execution enabled and later, if needed, use the same code to reap the benefits of representing models as computational graphs. -For example, -[`mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py) -defines a model that is eagerly executed. That same code is used to construct -and execute a graph in -[`mnist_graph_test.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py). +For example, the same model definition used to construct a graph in +[mnist.py`](https://github.com/tensorflow/models/tree/master/official/mnist/mnist.py) +can be trained with eager execution enabled as in [`mnist_eager.py`](https://github.com/tensorflow/models/tree/master/official/mnist/mnist_eager.py). Other models in the [examples directory](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/) diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 6cdbed5b896577f5622b1bd0123c289c798bc0a5..ddccfce3c07d20bde78de297db25437a347d75cb 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -138,6 +138,7 @@ py_test( size = "medium", srcs = ["python/estimator/extenders_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], # b/62863147 deps = [ ":extenders", "//tensorflow/contrib/data/python/ops:dataset_ops", diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 238cf287b768eee28b20202084eb244c085c8b75..a45f6934cc5b9bb7bccf148edbd7553b702c2127 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -177,6 +177,7 @@ def regression_head(weight_column=None, label_dimension=1, loss_reduction=losses.Reduction.SUM, loss_fn=None, + inverse_link_fn=None, name=None): """Creates a `_Head` for regression using the `mean_squared_error` loss. @@ -195,10 +196,16 @@ def regression_head(weight_column=None, `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN, label_dimension]`. - Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + Supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or `(labels, logits, features)` as arguments and returns unreduced loss with shape `[D0, D1, ... DN, label_dimension]`. + Also supports custom `inverse_link_fn`, also known as 'mean function'. + `inverse_link_fn` takes `logits` as argument and returns predicted values. + This function is the inverse of the link function defined in + https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function + Namely, for poisson regression, set `inverse_link_fn=tf.exp`. + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -209,7 +216,9 @@ def regression_head(weight_column=None, `[batch_size, label_dimension]`). loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. - loss_fn: Optional loss function. + loss_fn: Optional loss function. Defaults to `mean_squared_error`. + inverse_link_fn: Optional inverse link function, also known as 'mean + function'. Defaults to identity. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -224,6 +233,7 @@ def regression_head(weight_column=None, label_dimension=label_dimension, loss_reduction=loss_reduction, loss_fn=loss_fn, + inverse_link_fn=inverse_link_fn, name=name) diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 43cdfec9689879201305385499b3b784e1593d60..1411635228457218578c0297d4d901e9c86ca91a 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -446,7 +446,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, + keys.AUC_PR: 0.5972, } self._test_eval( head=head, @@ -478,7 +478,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, + keys.AUC_PR: 0.5972, } self._test_eval( head=head, @@ -509,7 +509,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, + keys.AUC_PR: 0.5972, } self._test_eval( head=head, @@ -543,7 +543,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, + keys.AUC_PR: 0.5972, } self._test_eval( head=head, @@ -573,7 +573,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.3333, - keys.AUC_PR: 0.7639, + keys.AUC_PR: 0.5972, keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4., keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3., keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3., @@ -621,7 +621,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.2000, - keys.AUC_PR: 0.7833, + keys.AUC_PR: 0.5833, } # Assert spec contains expected tensors. @@ -1095,7 +1095,7 @@ class MultiLabelHead(test.TestCase): # auc and auc_pr cannot be reliably calculated for only 4 samples, but # this assert tests that the algorithm remains consistent. keys.AUC: 0.4977, - keys.AUC_PR: 0.6645, + keys.AUC_PR: 0.4037, } self._test_eval( head=head, diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 65ea89ba1b9236d0bf4d2de430fab168ef50bf97..e47a6788f3b5440c4906b9f0430c802cf73237e3 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -306,8 +306,8 @@ class MultiHeadTest(test.TestCase): # this assert tests that the algorithm remains consistent. keys.AUC + '/head1': 0.1667, keys.AUC + '/head2': 0.3333, - keys.AUC_PR + '/head1': 0.6667, - keys.AUC_PR + '/head2': 0.5000, + keys.AUC_PR + '/head1': 0.49999964, + keys.AUC_PR + '/head2': 0.33333313, } # Assert spec contains expected tensors. diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index dfae034afc9a115dcc97e401e8a6d9c66a9b46e9..e0fae2c99292385c6dd32cc6002cee2076a2bb20 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -110,7 +110,8 @@ def replicate_model_fn(model_fn, Certain algorithms were chosen for aggregating results of computations on multiple towers: - Losses from all towers are reduced according to `loss_reduction`. - - Gradients are reduced using sum for each trainable variable. + - Gradients from all towers are reduced according to `loss_reduction` + for each trainable variable. - `eval_metrics_ops` are reduced per metric using `reduce_mean`. - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are reduced using concatenation. @@ -790,7 +791,7 @@ def _extract_tensors(tensors_and_vars): tensor, _ = tensor_and_var if isinstance(tensor, ops_lib.IndexedSlices): tensors.append(tensor.values) - else: + elif tensor is not None: tensors.append(tensor) return tensors diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index ab117e61a7059a224ebf6ff0355ae10363b758f5..d46a18aacfcd911c56a9f22dc9581060c7b458a6 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -240,6 +240,13 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): labels = np.array([[1.0], [2.0]]) with self.test_session() as session: + # Add another trainable variable that doesn't produce a gradient to + # verify that None gradients are supported. + _ = variable_scope.get_variable( + 'another_variable', + initializer=constant_op.constant(1, dtype=dtypes.float64), + dtype=dtypes.float64) + replicated_model_fn = replicate_model_fn.replicate_model_fn( self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn( @@ -1119,8 +1126,6 @@ class SplitBatchTest(test_util.TensorFlowTestCase): feature_shards, label_shards = replicate_model_fn._split_batch( features, labels, 2, device='/gpu:0') - print(feature_shards[0]['x'].eval()) - print(feature_shards[1]['x'].eval()) self.assertSparseValuesEqual( sparse_tensor.SparseTensorValue( indices=[[0, 0], [1, 0], [1, 1]], diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index c861cfff544a78617aa1ace730b50c094cf16330..7319eaa7de8db8e4677bdf64af3b0a72c1007a90 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -61,8 +61,8 @@ class _LossRelativeChangeHook(session_run_hook.SessionRunHook): loss = run_values.results assert loss is not None if self._prev_loss: - relative_change = (abs(loss - self._prev_loss) / - (1 + abs(self._prev_loss))) + relative_change = ( + abs(loss - self._prev_loss) / (1 + abs(self._prev_loss))) if relative_change < self._tolerance: run_context.request_stop() self._prev_loss = loss @@ -233,7 +233,57 @@ class _ModelFn(object): # TODO(agarwal,ands): support sharded input. class KMeansClustering(estimator.Estimator): - """An Estimator for K-Means clustering.""" + """An Estimator for K-Means clustering. + + Example: + ``` + import numpy as np + import tensorflow as tf + + num_points = 100 + dimensions = 2 + points = np.random.uniform(0, 1000, [num_points, dimensions]) + + def input_fn(): + return tf.train.limit_epochs( + tf.convert_to_tensor(points, dtype=tf.float32), num_epochs=1) + + num_clusters = 5 + kmeans = tf.contrib.factorization.KMeansClustering( + num_clusters=num_clusters, use_mini_batch=False) + + # train + num_iterations = 10 + previous_centers = None + for _ in xrange(num_iterations): + kmeans.train(input_fn) + cluster_centers = kmeans.cluster_centers() + if previous_centers is not None: + print 'delta:', cluster_centers - previous_centers + previous_centers = cluster_centers + print 'score:', kmeans.score(input_fn) + print 'cluster centers:', cluster_centers + + # map the input points to their clusters + cluster_indices = list(kmeans.predict_cluster_index(input_fn)) + for i, point in enumerate(points): + cluster_index = cluster_indices[i] + center = cluster_centers[cluster_index] + print 'point:', point, 'is in cluster', cluster_index, 'centered at', center + ``` + + The `SavedModel` saved by the `export_savedmodel` method does not include the + cluster centers. However, the cluster centers may be retrieved by the + latest checkpoint saved during training. Specifically, + ``` + kmeans.cluster_centers() + ``` + is equivalent to + ``` + tf.train.load_variable( + kmeans.model_dir, KMeansClustering.CLUSTER_CENTERS_VAR_NAME) + ``` + """ # Valid values for the distance_metric constructor argument. SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE @@ -253,6 +303,9 @@ class KMeansClustering(estimator.Estimator): CLUSTER_INDEX = 'cluster_index' ALL_DISTANCES = 'all_distances' + # Variable name used by cluster_centers(). + CLUSTER_CENTERS_VAR_NAME = clustering_ops.CLUSTERS_VAR_NAME + def __init__(self, num_clusters, model_dir=None, @@ -406,4 +459,4 @@ class KMeansClustering(estimator.Estimator): def cluster_centers(self): """Returns the cluster centers.""" - return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME) + return self.get_variable_value(KMeansClustering.CLUSTER_CENTERS_VAR_NAME) diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD index 6fc053759c58d30c24657dd22e7d12be46fc7a7e..a53e36c2d57114934b6843a05f97784aeaf82662 100644 --- a/tensorflow/contrib/feature_column/BUILD +++ b/tensorflow/contrib/feature_column/BUILD @@ -33,5 +33,34 @@ py_library( name = "sequential_feature_column", srcs = ["python/feature_column/sequential_feature_column.py"], srcs_version = "PY2AND3", - deps = [], + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:variable_scope", + "//tensorflow/python/feature_column", + ], +) + +py_test( + name = "sequential_feature_column_test", + srcs = ["python/feature_column/sequential_feature_column_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":sequential_feature_column", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", + "//tensorflow/python/feature_column", + "//third_party/py/numpy", + ], ) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py index 690a44ff4368663306733300a1ea70397fb93e1e..4ed7268e7a921284eed7767d870e56ecac39a3b1 100644 --- a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py +++ b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py @@ -12,8 +12,314 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Experimental methods for tf.feature_column sequential input.""" +"""Experimental methods for tf.feature_column sequence input.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function + + +import abc +import collections + + +from tensorflow.python.feature_column import feature_column as fc +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import variable_scope + +# TODO(b/73160931): Fix pydoc. +# pylint: disable=g-doc-args,missing-docstring,protected-access +# TODO(b/73827486): Support SequenceExample. + + +def sequence_input_layer( + features, + feature_columns, + weight_collections=None, + trainable=True, + scope=None): + """"Builds input layer for sequence input. + + All `feature_columns` must be sequence dense columns with the same + `sequence_length`. The output of this method can be fed into sequence + networks, such as RNN. + + The output of this method is a 3D `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ from + batch to batch. + + If multiple `feature_columns` are given with `Di` `num_elements` each, their + outputs are concatenated. So, the final `Tensor` has shape + `[batch_size, T, D0 + D1 + ... + Dn]`. + + Example: + + ```python + rating = sequence_numeric_column('rating') + watches = sequence_categorical_column_with_identity( + 'watches', num_buckets=1000) + watches_embedding = embedding_column(watches, dimension=10) + columns = [rating, watches] + + features = tf.parse_example(..., features=make_parse_example_spec(columns)) + input_layer, sequence_length = sequence_input_layer(features, columns) + + rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size) + outputs, state = tf.nn.dynamic_rnn( + rnn_cell, inputs=input_layer, sequence_length=sequence_length) + ``` + + Returns: + An `(input_layer, sequence_length)` tuple where: + - input_layer: A float `Tensor` of shape `[batch_size, T, D]`. + `T` is the maximum sequence length for this batch, which could differ + from batch to batch. `D` is the sum of `num_elements` for all + `feature_columns`. + - sequence_length: An int `Tensor` of shape `[batch_size]`. The sequence + length for each example. + Raises: + ValueError: If any of the `feature_columns` is the wrong type. + """ + feature_columns = fc._clean_feature_columns(feature_columns) + for c in feature_columns: + if not isinstance(c, _SequenceDenseColumn): + raise ValueError( + 'All feature_columns must be of type _SequenceDenseColumn. ' + 'Given (type {}): {}'.format(type(c), c)) + + with variable_scope.variable_scope( + scope, default_name='sequence_input_layer', values=features.values()): + builder = fc._LazyBuilder(features) + output_tensors = [] + sequence_lengths = [] + ordered_columns = [] + for column in sorted(feature_columns, key=lambda x: x.name): + ordered_columns.append(column) + with variable_scope.variable_scope( + None, default_name=column._var_scope_name): + dense_tensor, sequence_length = column._get_sequence_dense_tensor( + builder, + weight_collections=weight_collections, + trainable=trainable) + # Flattens the final dimension to produce a 3D Tensor. + num_elements = column._variable_shape.num_elements() + shape = array_ops.shape(dense_tensor) + output_tensors.append( + array_ops.reshape( + dense_tensor, + shape=array_ops.concat([shape[:2], [num_elements]], axis=0))) + sequence_lengths.append(sequence_length) + fc._verify_static_batch_size_equality(output_tensors, ordered_columns) + # TODO(b/73160931): Verify sequence_length equality. + return array_ops.concat(output_tensors, -1), sequence_lengths[0] + + +# TODO(b/73160931): Add remaining categorical columns. +def sequence_categorical_column_with_identity( + key, num_buckets, default_value=None): + return _SequenceCategoricalColumn( + fc.categorical_column_with_identity( + key=key, + num_buckets=num_buckets, + default_value=default_value)) + + +# TODO(b/73160931): Merge with embedding_column +def _sequence_embedding_column( + categorical_column, dimension, initializer=None, ckpt_to_load_from=None, + tensor_name_in_ckpt=None, max_norm=None, trainable=True): + if not isinstance(categorical_column, _SequenceCategoricalColumn): + raise ValueError( + 'categorical_column must be of type _SequenceCategoricalColumn. ' + 'Given (type {}): {}'.format( + type(categorical_column), categorical_column)) + return _SequenceEmbeddingColumn( + fc.embedding_column( + categorical_column, + dimension=dimension, + initializer=initializer, + ckpt_to_load_from=ckpt_to_load_from, + tensor_name_in_ckpt=tensor_name_in_ckpt, + max_norm=max_norm, + trainable=trainable)) + + +def sequence_numeric_column( + key, + shape=(1,), + default_value=0., + dtype=dtypes.float32): + # TODO(b/73160931): Add validations. + return _SequenceNumericColumn( + key, + shape=shape, + default_value=default_value, + dtype=dtype) + + +class _SequenceDenseColumn(fc._FeatureColumn): + """Represents dense sequence data.""" + + __metaclass__ = abc.ABCMeta + + TensorSequenceLengthPair = collections.namedtuple( # pylint: disable=invalid-name + 'TensorSequenceLengthPair', ['dense_tensor', 'sequence_length']) + + @abc.abstractproperty + def _variable_shape(self): + """`TensorShape` without batch and sequence dimensions.""" + pass + + @abc.abstractmethod + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + """Returns a `TensorSequenceLengthPair`.""" + pass + + +def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1): + with ops.name_scope(None, 'sequence_length') as name_scope: + row_ids = sp_tensor.indices[:, 0] + column_ids = sp_tensor.indices[:, 1] + column_ids += array_ops.ones_like(column_ids) + seq_length = ( + math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements) + # If the last n rows do not have ids, seq_length will have shape + # [batch_size - n]. Pad the remaining values with zeros. + n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1] + padding = array_ops.zeros(n_pad, dtype=seq_length.dtype) + return array_ops.concat([seq_length, padding], axis=0, name=name_scope) + + +class _SequenceCategoricalColumn( + fc._CategoricalColumn, + collections.namedtuple( + '_SequenceCategoricalColumn', ['categorical_column'])): + + @property + def name(self): + return self.categorical_column.name + + @property + def _parse_example_spec(self): + return self.categorical_column._parse_example_spec + + def _transform_feature(self, inputs): + return self.categorical_column._transform_feature(inputs) + + @property + def _num_buckets(self): + return self.categorical_column._num_buckets + + def _get_sparse_tensors(self, inputs, weight_collections=None, + trainable=None): + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) + id_tensor = sparse_tensors.id_tensor + weight_tensor = sparse_tensors.weight_tensor + # Expands final dimension, so that embeddings are not combined during + # embedding lookup. + check_id_rank = check_ops.assert_equal( + array_ops.rank(id_tensor), 2, + data=[ + 'Column {} expected ID tensor of rank 2. '.format(self.name), + 'id_tensor shape: ', array_ops.shape(id_tensor)]) + with ops.control_dependencies([check_id_rank]): + id_tensor = sparse_ops.sparse_reshape( + id_tensor, + shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0)) + if weight_tensor is not None: + check_weight_rank = check_ops.assert_equal( + array_ops.rank(weight_tensor), 2, + data=[ + 'Column {} expected weight tensor of rank 2.'.format(self.name), + 'weight_tensor shape:', array_ops.shape(weight_tensor)]) + with ops.control_dependencies([check_weight_rank]): + weight_tensor = sparse_ops.sparse_reshape( + weight_tensor, + shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0)) + return fc._CategoricalColumn.IdWeightPair(id_tensor, weight_tensor) + + def _sequence_length(self, inputs): + sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) + return _sequence_length_from_sparse_tensor(sparse_tensors.id_tensor) + + +class _SequenceEmbeddingColumn( + _SequenceDenseColumn, + collections.namedtuple('_SequenceEmbeddingColumn', ['embedding_column'])): + + @property + def name(self): + return self.embedding_column.name + + @property + def _parse_example_spec(self): + return self.embedding_column._parse_example_spec + + def _transform_feature(self, inputs): + return self.embedding_column._transform_feature(inputs) + + @property + def _variable_shape(self): + return self.embedding_column._variable_shape + + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + dense_tensor = self.embedding_column._get_dense_tensor( + inputs=inputs, + weight_collections=weight_collections, + trainable=trainable) + sequence_length = self.embedding_column.categorical_column._sequence_length( + inputs) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + + +class _SequenceNumericColumn( + _SequenceDenseColumn, + collections.namedtuple( + '_SequenceNumericColumn', + ['key', 'shape', 'default_value', 'dtype'])): + + @property + def name(self): + return self.key + + @property + def _parse_example_spec(self): + return {self.key: parsing_ops.VarLenFeature(self.dtype)} + + def _transform_feature(self, inputs): + return inputs.get(self.key) + + @property + def _variable_shape(self): + return tensor_shape.TensorShape(self.shape) + + def _get_sequence_dense_tensor( + self, inputs, weight_collections=None, trainable=None): + # Do nothing with weight_collections and trainable since no variables are + # created in this function. + del weight_collections + del trainable + sp_tensor = inputs.get(self) + dense_tensor = sparse_ops.sparse_tensor_to_dense( + sp_tensor, default_value=self.default_value) + # Reshape into [batch_size, T, variable_shape]. + dense_shape = array_ops.concat( + [array_ops.shape(dense_tensor)[:1], [-1], self._variable_shape], + axis=0) + dense_tensor = array_ops.reshape(dense_tensor, shape=dense_shape) + sequence_length = _sequence_length_from_sparse_tensor( + sp_tensor, num_elements=self._variable_shape.num_elements()) + return _SequenceDenseColumn.TensorSequenceLengthPair( + dense_tensor=dense_tensor, sequence_length=sequence_length) + +# pylint: enable=g-doc-args,missing-docstring,protected-access diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py new file mode 100644 index 0000000000000000000000000000000000000000..59674869a27c3a40ab9cb3dcede384d1cda7ce27 --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column_test.py @@ -0,0 +1,471 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for sequential_feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.feature_column.python.feature_column import sequential_feature_column as sfc +from tensorflow.python.feature_column.feature_column import _LazyBuilder +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import test +from tensorflow.python.training import monitored_session + + +class SequenceInputLayerTest(test.TestCase): + + def test_embedding_column(self): + vocabulary_size = 3 + sparse_input_a = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + sparse_input_b = sparse_tensor.SparseTensorValue( + # example 0, ids [1] + # example 1, ids [2, 0] + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + + embedding_dimension_a = 2 + embedding_values_a = ( + (1., 2.), # id 0 + (3., 4.), # id 1 + (5., 6.) # id 2 + ) + embedding_dimension_b = 3 + embedding_values_b = ( + (11., 12., 13.), # id 0 + (14., 15., 16.), # id 1 + (17., 18., 19.) # id 2 + ) + def _get_initializer(embedding_dimension, embedding_values): + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + return _initializer + + expected_input_layer = [ + # example 0, ids_a [2], ids_b [1] + [[5., 6., 14., 15., 16.], [0., 0., 0., 0., 0.]], + # example 1, ids_a [0, 1], ids_b [2, 0] + [[1., 2., 17., 18., 19.], [3., 4., 11., 12., 13.]], + ] + expected_sequence_length = [1, 2] + + categorical_column_a = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column_a = sfc._sequence_embedding_column( + categorical_column_a, dimension=embedding_dimension_a, + initializer=_get_initializer(embedding_dimension_a, embedding_values_a)) + categorical_column_b = sfc.sequence_categorical_column_with_identity( + key='bbb', num_buckets=vocabulary_size) + embedding_column_b = sfc._sequence_embedding_column( + categorical_column_b, dimension=embedding_dimension_b, + initializer=_get_initializer(embedding_dimension_b, embedding_values_b)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={ + 'aaa': sparse_input_a, + 'bbb': sparse_input_b, + }, + # Test that columns are reordered alphabetically. + feature_columns=[embedding_column_b, embedding_column_a]) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('sequence_input_layer/aaa_embedding/embedding_weights:0', + 'sequence_input_layer/bbb_embedding/embedding_weights:0'), + tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values_a, global_vars[0].eval(session=sess)) + self.assertAllEqual(embedding_values_b, global_vars[1].eval(session=sess)) + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_numeric_column(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_input_layer = [ + [[0.], [1.]], + [[10.], [0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_numeric_column_multi_dim(self): + """Tests sequence_input_layer for multi-dimensional numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + # The output of numeric_column._get_dense_tensor should be flattened. + expected_input_layer = [ + [[0., 1., 2., 3.], [4., 5., 6., 7.]], + [[10., 11., 12., 13.], [0., 0., 0., 0.]], + ] + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + input_layer, sequence_length = sfc.sequence_input_layer( + features={'aaa': sparse_input}, + feature_columns=[numeric_column]) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(expected_input_layer, input_layer.eval(session=sess)) + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +def _assert_sparse_tensor_value(test_case, expected, actual): + test_case.assertEqual(np.int64, np.array(actual.indices).dtype) + test_case.assertAllEqual(expected.indices, actual.indices) + + test_case.assertEqual( + np.array(expected.values).dtype, np.array(actual.values).dtype) + test_case.assertAllEqual(expected.values, actual.values) + + test_case.assertEqual(np.int64, np.array(actual.dense_shape).dtype) + test_case.assertAllEqual(expected.dense_shape, actual.dense_shape) + + +class SequenceCategoricalColumnWithIdentityTest(test.TestCase): + + def test_get_sparse_tensors(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + expected_sparse_ids = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=np.array((1, 2, 0), dtype=np.int64), + dense_shape=(2, 2, 1)) + + id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs})) + + self.assertIsNone(id_weight_pair.weight_tensor) + with monitored_session.MonitoredSession() as sess: + _assert_sparse_tensor_value( + self, + expected_sparse_ids, + id_weight_pair.id_tensor.eval(session=sess)) + + def test_get_sparse_tensors_inputs3d(self): + """Tests _get_sparse_tensors when the input is already 3D Tensor.""" + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0, 0), (1, 0, 0), (1, 1, 0)), + values=(1, 2, 0), + dense_shape=(2, 2, 1)) + + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'Column aaa expected ID tensor of rank 2\.\s*' + r'id_tensor shape:\s*\[2 2 1\]'): + id_weight_pair = column._get_sparse_tensors( + _LazyBuilder({'aaa': inputs})) + with monitored_session.MonitoredSession() as sess: + id_weight_pair.id_tensor.eval(session=sess) + + def test_sequence_length(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((0, 0), (1, 0), (1, 1)), + values=(1, 2, 0), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_zeros(self): + column = sfc.sequence_categorical_column_with_identity( + 'aaa', num_buckets=3) + inputs = sparse_tensor.SparseTensorValue( + indices=((1, 0), (3, 0), (3, 1)), + values=(1, 2, 0), + dense_shape=(5, 2)) + expected_sequence_length = [0, 1, 0, 2, 0] + + sequence_length = column._sequence_length(_LazyBuilder({'aaa': inputs})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceEmbeddingColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + # example 2, ids [] + # example 3, ids [1] + indices=((0, 0), (1, 0), (1, 1), (3, 0)), + values=(2, 0, 1, 1), + dense_shape=(4, 2)) + + embedding_dimension = 2 + embedding_values = ( + (1., 2.), # id 0 + (3., 5.), # id 1 + (7., 11.) # id 2 + ) + def _initializer(shape, dtype, partition_info): + self.assertAllEqual((vocabulary_size, embedding_dimension), shape) + self.assertEqual(dtypes.float32, dtype) + self.assertIsNone(partition_info) + return embedding_values + + expected_lookups = [ + # example 0, ids [2] + [[7., 11.], [0., 0.]], + # example 1, ids [0, 1] + [[1., 2.], [3., 5.]], + # example 2, ids [] + [[0., 0.], [0., 0.]], + # example 3, ids [1] + [[3., 5.], [0., 0.]], + ] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = sfc._sequence_embedding_column( + categorical_column, dimension=embedding_dimension, + initializer=_initializer) + + embedding_lookup, _ = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertItemsEqual( + ('embedding_weights:0',), tuple([v.name for v in global_vars])) + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess)) + self.assertAllEqual(expected_lookups, embedding_lookup.eval(session=sess)) + + def test_sequence_length(self): + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [2] + # example 1, ids [0, 1] + indices=((0, 0), (1, 0), (1, 1)), + values=(2, 0, 1), + dense_shape=(2, 2)) + expected_sequence_length = [1, 2] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = sfc._sequence_embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + vocabulary_size = 3 + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, ids [] + # example 1, ids [2] + # example 2, ids [0, 1] + # example 3, ids [] + # example 4, ids [1] + # example 5, ids [] + indices=((1, 0), (2, 0), (2, 1), (4, 0)), + values=(2, 0, 1, 1), + dense_shape=(6, 2)) + expected_sequence_length = [0, 1, 2, 0, 1, 0] + + categorical_column = sfc.sequence_categorical_column_with_identity( + key='aaa', num_buckets=vocabulary_size) + embedding_column = sfc._sequence_embedding_column( + categorical_column, dimension=2) + + _, sequence_length = embedding_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +class SequenceNumericColumnTest(test.TestCase): + + def test_get_sequence_dense_tensor(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_dense_tensor = [ + [[0.], [1.]], + [[10.], [0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa') + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_sequence_dense_tensor_with_shape(self): + """Tests get_sequence_dense_tensor with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_dense_tensor = [ + [[0., 1., 2.], [3., 4., 5.]], + [[10., 11., 12.], [0., 0., 0.]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_get_dense_tensor_multi_dim(self): + """Tests get_sequence_dense_tensor for multi-dim numeric_column.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]] + # example 1, [[[10., 11.], [12., 13.]]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), (0, 6), (0, 7), + (1, 0), (1, 1), (1, 2), (1, 3)), + values=(0., 1., 2., 3., 4., 5., 6., 7., 10., 11., 12., 13.), + dense_shape=(2, 8)) + expected_dense_tensor = [ + [[[0., 1.], [2., 3.]], [[4., 5.], [6., 7.]]], + [[[10., 11.], [12., 13.]], [[0., 0.], [0., 0.]]], + ] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(2, 2)) + + dense_tensor, _ = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_dense_tensor, dense_tensor.eval(session=sess)) + + def test_sequence_length(self): + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0., 1., 2.], [3., 4., 5.]] + # example 1, [[10., 11., 12.]] + indices=((0, 0), (0, 1), (0, 2), (0, 3), (0, 4), (0, 5), + (1, 0), (1, 1), (1, 2)), + values=(0., 1., 2., 3., 4., 5., 10., 11., 12.), + dense_shape=(2, 6)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa', shape=(3,)) + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_shape(self): + """Tests _sequence_length with shape !=(1,).""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [[0.], [1]] + # example 1, [[10.]] + indices=((0, 0), (0, 1), (1, 0)), + values=(0., 1., 10.), + dense_shape=(2, 2)) + expected_sequence_length = [2, 1] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + def test_sequence_length_with_empty_rows(self): + """Tests _sequence_length when some examples do not have ids.""" + sparse_input = sparse_tensor.SparseTensorValue( + # example 0, values [] + # example 1, values [[0.], [1.]] + # example 2, [[2.]] + # example 3, values [] + # example 4, [[3.]] + # example 5, values [] + indices=((1, 0), (1, 1), (2, 0), (4, 0)), + values=(0., 1., 2., 3.), + dense_shape=(6, 2)) + expected_sequence_length = [0, 2, 1, 0, 1, 0] + numeric_column = sfc.sequence_numeric_column('aaa') + + _, sequence_length = numeric_column._get_sequence_dense_tensor( + _LazyBuilder({'aaa': sparse_input})) + + with monitored_session.MonitoredSession() as sess: + self.assertAllEqual( + expected_sequence_length, sequence_length.eval(session=sess)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 9e5f54f0973eae899ca65e4098358107053cb7d4..50868c6d6c943c9f2af162cee7157f596e1f9a69 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -28,7 +28,6 @@ tf_custom_op_py_library( "python/framework/graph_util.py", "python/framework/tensor_util.py", "python/ops/__init__.py", - "python/ops/accumulate_n_v2.py", "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", @@ -64,6 +63,7 @@ tf_custom_op_py_library( "//tensorflow/python:platform", "//tensorflow/python:pywrap_tensorflow", "//tensorflow/python:script_ops", + "//tensorflow/python:smart_cond", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:state_ops_gen", @@ -161,23 +161,6 @@ py_test( ], ) -py_test( - name = "accumulate_n_v2_test", - size = "small", - srcs = ["python/ops/accumulate_n_v2_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - "//third_party/py/numpy", - ], -) - cuda_py_test( name = "critical_section_test", size = "medium", @@ -185,31 +168,14 @@ cuda_py_test( additional_deps = [ "//tensorflow/python:client_testlib", ":framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:platform_test", "//tensorflow/python:resource_variable_ops", - ], -) - -py_test( - name = "accumulate_n_v2_eager_test", - size = "small", - srcs = ["python/ops/accumulate_n_v2_eager_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python/eager:backprop", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:tape", - "//third_party/py/numpy", + "//tensorflow/python:tensor_array_ops", ], ) diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index fb101c36538f72d0665c41a625824eb0d66f48ce..80632500912e92b74b0de5d66277f79dfcba1938 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -85,6 +85,11 @@ See the @{$python/contrib.framework} guide. @@py_func @@sort +@@get_placeholders + +@@smart_cond +@@smart_constant_value + @@CriticalSection @@BoundedTensorSpec @@ -102,10 +107,10 @@ from tensorflow.contrib.framework.python.ops import * from tensorflow.python.framework.ops import prepend_name_scope from tensorflow.python.framework.ops import strip_name_scope - +from tensorflow.python.framework.smart_cond import smart_cond +from tensorflow.python.framework.smart_cond import smart_constant_value from tensorflow.python.framework.tensor_spec import BoundedTensorSpec from tensorflow.python.framework.tensor_spec import TensorSpec - from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['nest'] diff --git a/tensorflow/contrib/framework/python/framework/experimental_test.py b/tensorflow/contrib/framework/python/framework/experimental_test.py index 8e54e09e04ee3c0ddbd4fa84cc0912cb70c93e62..cfdc7df7d8fd4c1406bf447a79038ac33b11e047 100644 --- a/tensorflow/contrib/framework/python/framework/experimental_test.py +++ b/tensorflow/contrib/framework/python/framework/experimental_test.py @@ -49,7 +49,6 @@ class ExperimentalTest(test.TestCase): "\nTHIS FUNCTION IS EXPERIMENTAL. It may change or " "be removed at any time, and without warning." "\n" - "\n" "\nArgs:" "\n arg0: Arg 0." "\n arg1: Arg 1." diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index a18ff2320d99726bb355ff6179fc97a070c2fec7..49eec3a3f1a0f357ea3adfade51e71cb0f89942d 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -133,6 +133,18 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, def get_placeholders(graph): """Get placeholders of a graph. + For example: + + ```python + a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a') + a = tf.placeholder(dtype=tf.int32, shape=[3, 2], name='b') + + tf.contrib.framework.get_placeholders(tf.get_default_graph()) + # Returns: + # [, + # ] + ``` + Args: graph: A tf.Graph. Returns: diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py deleted file mode 100644 index 476528b0dd3df05239d5dc402b466e06dd789985..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Ops that will eventually be folded into tensorflow/python/ops/math_ops.py -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_math_ops -from tensorflow.python.ops import math_ops - - - -def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): - """Returns the element-wise sum of a list of tensors. - - Optionally, pass `shape` and `tensor_dtype` for shape and type checking, - otherwise, these are inferred. - - `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not - wait for all of its inputs to be ready before beginning to sum. This can - save memory if inputs are ready at different times, since minimum temporary - storage is proportional to the output size rather than the inputs size. - - Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. - - For example: - - ```python - a = tf.constant([[1, 2], [3, 4]]) - b = tf.constant([[5, 0], [0, 6]]) - tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] - - # Explicitly pass shape and type - tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) - # [[7, 4], - # [6, 14]] - ``` - - Args: - inputs: A list of `Tensor` objects, each with same shape and type. - shape: Shape of elements of `inputs`. - tensor_dtype: The type of `inputs`. - name: A name for the operation (optional). - - Returns: - A `Tensor` of same shape and type as the elements of `inputs`. - - Raises: - ValueError: If `inputs` don't all have same shape and dtype or the shape - cannot be inferred. - """ - _INPUTS_ERR_MSG = ValueError("inputs must be a list of at least one Tensor" - "with the same dtype and shape") - if not inputs or not isinstance(inputs, (list, tuple)): - raise _INPUTS_ERR_MSG - inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) - if not all(isinstance(x, ops.Tensor) for x in inputs): - raise _INPUTS_ERR_MSG - if not all(x.dtype == inputs[0].dtype for x in inputs): - raise _INPUTS_ERR_MSG - if shape is not None: - shape = tensor_shape.as_shape(shape) - else: - shape = tensor_shape.unknown_shape() - for input_tensor in inputs: - if isinstance(input_tensor, ops.Tensor): - shape = shape.merge_with(input_tensor.get_shape()) - - # tensor_dtype is for safety only; operator's output type computed in C++ - if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: - raise TypeError("tensor_dtype is {}, but input is of type {}" - .format(tensor_dtype, inputs[0].dtype)) - - if len(inputs) == 1 and name is None: - return inputs[0] - elif len(inputs) == 1 and name is not None: - return array_ops.identity(inputs[0], name=name) - elif context.in_eager_mode(): - # TemporaryVariable not currently supported in eager mode; fall back - # onto AddN for now. - # TODO(frreiss) remove this once the lifetime of eager variables gets - # addressed - return math_ops.add_n(inputs, name=name) - else: - return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) - -# The following code should eventually be merged into -# tensorflow/python/ops/math_grad.py -@ops.RegisterGradient("AccumulateNV2") -def _AddNGrad(op, grad): - """Same as gradient for AddN. Copies the gradient to all inputs.""" - # Not broadcasting. - return [grad] * len(op.inputs) diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py index 182fec924febb74a23b82b1664d137f033f3b1b4..ab603cc18e12136baea35b10999771c0ada2dd2c 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_ops.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -27,7 +27,11 @@ from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_resource_variable_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util import nest @@ -38,7 +42,8 @@ CRITICAL_SECTION_EXECUTIONS = "critical_section_executions" class _ExecutionSignature( collections.namedtuple("_ExecutionSignature", - ("op", "exclusive_resource_access"))): + ("op", "handle", + "resources", "exclusive_resource_access"))): """A class storing an `ExecuteInCriticalResource` op and associated attrs.""" pass @@ -112,16 +117,18 @@ class CriticalSection(object): ``` """ - def __init__(self, name=None, critical_section_def=None, import_scope=None): + def __init__(self, name=None, shared_name=None, + critical_section_def=None, import_scope=None): """Creates a critical section.""" if critical_section_def and name is not None: - raise ValueError("critical_section_def and name are mutually exclusive.") + raise ValueError("critical_section_def and shared_name are " + "mutually exclusive.") if critical_section_def: self._init_from_proto(critical_section_def, import_scope=import_scope) else: - self._init_from_args(name) + self._init_from_args(name, shared_name) - def _init_from_proto(self, critical_section_def, import_scope): + def _init_from_proto(self, critical_section_def, import_scope): # pylint: disable=invalid-name raise NotImplementedError("Not yet implemented") # TODO(ebrevdo): Re-enable once CriticalSection is in core. # assert isinstance( @@ -133,18 +140,20 @@ class CriticalSection(object): # critical_section_def.critical_section_name, # import_scope=import_scope)) - def _init_from_args(self, name): + def _init_from_args(self, name, shared_name): # pylint: disable=invalid-name """Initialize the CriticalSection from constructor arguments.""" with ops.name_scope(name, "CriticalSection", []) as name: - with ops.control_dependencies(None): + with ops.init_scope(): # pylint: disable=protected-access - handle_name = ops._name_from_scope_name(name) container = ops.get_default_graph()._container # pylint: enable=protected-access + if shared_name is None: + shared_name = name if container is None: container = "" - self._handle = gen_resource_variable_ops.critical_section_op( - shared_name=handle_name, name=name) + self._handle = gen_resource_variable_ops.mutex_v2( + shared_name=shared_name, container=container, name=name) + if context.in_graph_mode(): ops.add_to_collections(CRITICAL_SECTIONS, self) @@ -183,68 +192,98 @@ class CriticalSection(object): name = kwargs.pop("name", None) exclusive_resource_access = kwargs.pop("exclusive_resource_access", True) - args = nest.map_structure(ops.convert_to_tensor, args) with ops.name_scope(name, "critical_section_execute", []): - fn_op = function.make_defun_op(fn, *args, **kwargs) - flat_dtypes = nest.flatten(fn_op.output_dtypes) - flat_shapes = nest.flatten(fn_op.output_shapes) - all_inputs = nest.flatten(args) + fn_op.captured_inputs - if self._handle in all_inputs: + lock = gen_resource_variable_ops.mutex_lock(self._handle) + + with ops.control_dependencies([lock]): + c_known_ops = set() + c_captured_tensors = set() + + def add_op_internal(op): + c_known_ops.add(op) + for i in op.inputs: + if i.op not in c_known_ops: + c_captured_tensors.add(i) + + c = function.HelperContext(add_op_internal) + with c: + r = fn(*args, **kwargs) + + resource_inputs = set([ + x for x in + list(nest.flatten(args)) + nest.flatten(kwargs.values()) + + list(c_captured_tensors) + if tensor_util.is_tensor(x) and x.dtype == dtypes.resource]) + + if self._handle in resource_inputs: raise ValueError("The function fn attempts to access the " - "CriticalSection in which it would be running. This " - "is illegal and would cause deadlocks. " + "CriticalSection in which it would be running. " + "This is illegal and would cause deadlocks. " "CriticalSection: %s." % self._handle) if context.in_graph_mode(): # Collections and op introspection does not work in eager # mode. This is generally ok; since eager mode (as of # writing) executes sequentially anyway. - all_input_resources = [ - x for x in all_inputs if x.dtype == dtypes.resource] for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): - if sg.op.inputs[0].name == self._handle.name: + sg_handle_name = ops.convert_to_tensor(sg.handle).name + self_handle_name = ops.convert_to_tensor(self._handle).name + if sg_handle_name == self_handle_name: # Other executions in the same critical section are allowed. continue if not (exclusive_resource_access or sg.exclusive_resource_access): # Neither execution requested exclusive access. continue - sg_input_names = [y.name for y in sg.op.inputs[1:]] - for res in all_input_resources: - if res.name in sg_input_names: - raise ValueError( - "This execution would access resource %s; but either this " - "execution (CriticalSection: %s) or Execution '%s' " - "(CriticalSection: %s) requested exclusive resource access " - "of this resource for their critical section. Did you mean " - "to call execute with keyword argument " - "exclusive_resource_access=False?" - % (res.name, - self.name, - sg.op.name, - sg.op.inputs[0].op.name)) - - flat_outputs = gen_resource_variable_ops.execute_in_critical_section( - critical_section=self._handle, - arguments=all_inputs, - f=fn_op, - output_types=flat_dtypes, - output_shapes=flat_shapes) + resource_intersection = resource_inputs.intersection(sg.resources) + if resource_intersection: + raise ValueError( + "This execution would access resources: %s. Either this " + "lock (CriticalSection: %s) or lock '%s' " + "(CriticalSection: %s) requested exclusive resource access " + "of this resource. Did you mean to call execute with keyword " + "argument exclusive_resource_access=False?" % + (list(resource_intersection), self._handle.name, + sg.op.name, sg.handle.name)) + + def identity(x): # pylint: disable=invalid-name + if isinstance(x, tensor_array_ops.TensorArray): + return x.identity() + elif isinstance(x, ops.Operation): + return control_flow_ops.group(x) + elif context.in_eager_mode() and x is None: + return None + else: + return array_ops.identity(x) + + r_flat = [identity(x) for x in nest.flatten(r)] + + with ops.control_dependencies(r_flat): + # The identity must run on the same machine as self._handle + with ops.colocate_with(self._handle): + # Do not use array_ops.identity as there are special + # optimizations within TensorFlow which seem to elide it + # even when optimizations are disabled(!). + ensure_lock_exists = gen_resource_variable_ops.consume_mutex_lock( + lock) + + # Make sure that if any element of r is accessed, all of + # them are executed together. + r = nest.pack_sequence_as( + r, control_flow_ops.tuple(nest.flatten(r))) + + with ops.control_dependencies([ensure_lock_exists]): + outputs = nest.map_structure(identity, r) if context.in_graph_mode(): - if isinstance(flat_outputs, ops.Operation): - flat_outputs = [flat_outputs] - op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor) - else flat_outputs[0]) signature = _ExecutionSignature( - op=op, + op=lock.op, + handle=self._handle, + resources=list(resource_inputs), exclusive_resource_access=exclusive_resource_access) ops.add_to_collections( CRITICAL_SECTION_EXECUTIONS, signature) - return (flat_outputs[0] - if (len(flat_outputs) == 1 - and isinstance(flat_outputs[0], ops.Operation)) - else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs)) + return outputs # TODO(ebrevdo): Re-enable once CriticalSection is in core. @@ -276,6 +315,7 @@ class CriticalSection(object): # def _execution_to_proto_fn(execution_signature, export_scope=None): # """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`. +# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list. # Args: # execution_signature: Instance of `_ExecutionSignature`. @@ -298,6 +338,7 @@ class CriticalSection(object): # def _execution_from_proto_fn(op_def, import_scope=None): # """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`.""" +# # TODO(ebrevdo): Update for _ExecutionSignature storing resource list. # assert isinstance( # op_def, critical_section_pb2.CriticalSectionExecutionDef) diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py index a416724d3ba1719471d70667e140f9cd2daf86c7..c916592ce1979fe3a79cf28ad4bdac44284cce97 100644 --- a/tensorflow/contrib/framework/python/ops/critical_section_test.py +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -19,12 +19,10 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework.python.ops import critical_section_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test # TODO(ebrevdo): Re-enable once CriticalSection is in core. @@ -35,7 +33,7 @@ class CriticalSectionTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testCreateCriticalSection(self): - cs = critical_section_ops.CriticalSection(name="cs") + cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") def fn(a, b): @@ -45,16 +43,72 @@ class CriticalSectionTest(test.TestCase): with ops.control_dependencies([nv]): return array_ops.identity(c) - num_concurrent = 1000 + num_concurrent = 100 r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)] self.evaluate(v.initializer) r_value = self.evaluate(r) self.assertAllClose([2.0 * i for i in range(num_concurrent)], sorted(r_value)) + @test_util.run_in_graph_and_eager_modes() + def testCriticalSectionWithControlFlow(self): + for outer_cond in [False, True]: + for inner_cond in [False, True]: + cs = critical_section_ops.CriticalSection(shared_name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + num_concurrent = 100 + + # pylint: disable=cell-var-from-loop + def fn(a, b): + c = v.read_value() + def true_fn(): + with ops.control_dependencies([c]): + nv = v.assign_add(a * b) + with ops.control_dependencies([nv]): + return array_ops.identity(c) + return control_flow_ops.cond( + array_ops.identity(inner_cond), true_fn, lambda: c) + + def execute(): + return cs.execute(fn, 1.0, 2.0) + + r = [ + control_flow_ops.cond(array_ops.identity(outer_cond), + execute, + v.read_value) + for _ in range(num_concurrent) + ] + # pylint: enable=cell-var-from-loop + + self.evaluate(v.initializer) + r_value = self.evaluate(r) + if inner_cond and outer_cond: + self.assertAllClose([2.0 * i for i in range(num_concurrent)], + sorted(r_value)) + else: + self.assertAllClose([0] * num_concurrent, r_value) + + def testCriticalSectionInParallelDoesntDeadlockOnError(self): + # No eager mode execution of this test because eager does not + # run fn() in parallel, which is where the deadlock could + # potentially occur (in graph mode). + cs = critical_section_ops.CriticalSection(shared_name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + def fn(i): + error = control_flow_ops.Assert((i % 2) == 1, ["Error"]) + with ops.control_dependencies([error]): + return v.read_value() + num_concurrent = 2 + r = [cs.execute(fn, i) for i in range(num_concurrent)] + self.evaluate(v.initializer) + for _ in range(100): + with self.assertRaisesOpError("Error"): + self.evaluate(r) + @test_util.run_in_graph_and_eager_modes() def testCreateCriticalSectionFnReturnsOp(self): - cs = critical_section_ops.CriticalSection(name="cs") + cs = critical_section_ops.CriticalSection(shared_name="cs") v = resource_variable_ops.ResourceVariable(0.0, name="v") def fn_return_op(a, b): @@ -62,7 +116,7 @@ class CriticalSectionTest(test.TestCase): with ops.control_dependencies([c]): nv = v.assign_add(a * b) with ops.control_dependencies([nv]): - return () + return control_flow_ops.no_op() num_concurrent = 100 r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)] @@ -71,47 +125,25 @@ class CriticalSectionTest(test.TestCase): final_v = self.evaluate(v) self.assertAllClose(2.0 * num_concurrent, final_v) - def testCreateCriticalSectionRaw(self): - cs = critical_section_ops.CriticalSection(name="cs") - v = resource_variable_ops.ResourceVariable(0.0, name="v") - - @function.Defun(dtypes.float32, dtypes.float32) - def fn(a, b): - c = v.read_value() - with ops.control_dependencies([c]): - nv = v.assign_add(a * b) - with ops.control_dependencies([nv]): - return array_ops.identity(c) - - def execute(fn, *args): - output_args = fn.definition.signature.output_arg - return resource_variable_ops.execute_in_critical_section( - critical_section=cs._handle, - arguments=list(args) + fn.captured_inputs, - f=fn, - output_types=[out.type for out in output_args], - output_shapes=[tensor_shape.TensorShape(None) for _ in output_args]) - - num_concurrent = 1000 - r = [execute(fn, 1.0, 2.0)[0] for _ in range(num_concurrent)] - self.evaluate(v.initializer) - r_value = self.evaluate(r) - self.assertAllClose([2.0 * i for i in range(num_concurrent)], - sorted(r_value)) - def testCollection(self): - cs = critical_section_ops.CriticalSection(name="cs") + cs = critical_section_ops.CriticalSection(shared_name="cs") self.assertIn( cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)) - execute_op = cs.execute(lambda x: x + 1, 1.0).op + execute = cs.execute(lambda x: x + 1, 1.0, name="my_execute") + execute_op = [ + x for x in execute.graph.get_operations() + if "my_execute" in x.name and "MutexLock" in x.type + ][0] self.assertIn( execute_op, [signature.op for signature in ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)]) - @test_util.run_in_graph_and_eager_modes() def testRecursiveCriticalSectionAccessIsIllegal(self): - cs = critical_section_ops.CriticalSection(name="cs") + # This does not work properly in eager mode. Eager users will + # just hit a deadlock if they do this. But at least it'll be easier + # to debug. + cs = critical_section_ops.CriticalSection(shared_name="cs") def fn(x): return cs.execute(lambda x: x+1, x) with self.assertRaisesRegexp( @@ -167,7 +199,7 @@ class CriticalSectionTest(test.TestCase): # self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name) # def testToProto(self): - # cs = critical_section_ops.CriticalSection(name="cs") + # cs = critical_section_ops.CriticalSection(shared_name="cs") # proto = cs.to_proto() # self.assertEqual(proto.critical_section_name, cs._handle.name) # cs_copy = critical_section_ops.CriticalSection.from_proto(proto) diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 5db34f0f8db93620b8b4a6b71f63b66ac718ee30..0eb0e3cbe20f5804db5476c08167d4e1c9080cfa 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -55,6 +55,7 @@ py_test( name = "train_test", srcs = ["python/train_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":features", ":namedtuples", diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 1e18c699ba93b5f524341c65d0a2db84556b65a2..61dc8646ddc10605561ae6b19e90f4739c346608 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -181,7 +181,8 @@ class ClassifierMetricsTest(test.TestCase): batch_size = 3 img = array_ops.ones([batch_size, 299, 299, 3]) pool = _run_with_mock( - classifier_metrics.run_inception, img, + classifier_metrics.run_inception, + img, output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) self.assertTrue(isinstance(pool, ops.Tensor)) @@ -195,9 +196,12 @@ class ClassifierMetricsTest(test.TestCase): batch_size = 3 img = array_ops.ones([batch_size, 299, 299, 3]) logits, pool = _run_with_mock( - classifier_metrics.run_inception, img, - output_tensor=[classifier_metrics.INCEPTION_OUTPUT, - classifier_metrics.INCEPTION_FINAL_POOL]) + classifier_metrics.run_inception, + img, + output_tensor=[ + classifier_metrics.INCEPTION_OUTPUT, + classifier_metrics.INCEPTION_FINAL_POOL + ]) self.assertTrue(isinstance(logits, ops.Tensor)) self.assertTrue(isinstance(pool, ops.Tensor)) @@ -209,8 +213,10 @@ class ClassifierMetricsTest(test.TestCase): def test_inception_score_graph(self): """Test `inception_score` graph construction.""" - score = _run_with_mock(classifier_metrics.inception_score, - array_ops.zeros([6, 299, 299, 3]), num_batches=3) + score = _run_with_mock( + classifier_metrics.inception_score, + array_ops.zeros([6, 299, 299, 3]), + num_batches=3) self.assertTrue(isinstance(score, ops.Tensor)) score.shape.assert_has_rank(0) @@ -248,12 +254,14 @@ class ClassifierMetricsTest(test.TestCase): array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q) with self.assertRaisesRegexp(ValueError, 'must be floating type'): - classifier_metrics._kl_divergence( - p, array_ops.zeros([8, 10], dtype=dtypes.int32), q) + classifier_metrics._kl_divergence(p, + array_ops.zeros( + [8, 10], dtype=dtypes.int32), q) with self.assertRaisesRegexp(ValueError, 'must be floating type'): - classifier_metrics._kl_divergence( - p, p_logits, array_ops.zeros([10], dtype=dtypes.int32)) + classifier_metrics._kl_divergence(p, p_logits, + array_ops.zeros( + [10], dtype=dtypes.int32)) with self.assertRaisesRegexp(ValueError, 'must have rank 2'): classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q) @@ -266,8 +274,9 @@ class ClassifierMetricsTest(test.TestCase): def test_inception_score_value(self): """Test that `inception_score` gives the correct value.""" - logits = np.array([np.array([1, 2] * 500 + [4]), - np.array([4, 5] * 500 + [6])]) + logits = np.array( + [np.array([1, 2] * 500 + [4]), + np.array([4, 5] * 500 + [6])]) unused_image = array_ops.zeros([2, 299, 299, 3]) incscore = _run_with_mock(classifier_metrics.inception_score, unused_image) @@ -285,9 +294,11 @@ class ClassifierMetricsTest(test.TestCase): test_pool_real_a = np.float32(np.random.randn(512, 256)) test_pool_gen_a = np.float32(np.random.randn(512, 256)) - fid_op = _run_with_mock(classifier_metrics.frechet_classifier_distance, - test_pool_real_a, test_pool_gen_a, - classifier_fn=lambda x: x) + fid_op = _run_with_mock( + classifier_metrics.frechet_classifier_distance, + test_pool_real_a, + test_pool_gen_a, + classifier_fn=lambda x: x) with self.test_session() as sess: actual_fid = sess.run(fid_op) @@ -296,6 +307,33 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose(expected_fid, actual_fid, 0.0001) + def test_frechet_classifier_distance_covariance(self): + """Test that `frechet_classifier_distance` takes covariance into account.""" + np.random.seed(0) + + # Make num_examples > num_features to ensure scipy's sqrtm function + # doesn't return a complex matrix. + test_pool_reals, test_pool_gens = [], [] + for i in range(1, 11, 2): + test_pool_reals.append(np.float32(np.random.randn(2048, 256) * i)) + test_pool_gens.append(np.float32(np.random.randn(2048, 256) * i)) + + fid_ops = [] + for i in range(len(test_pool_reals)): + fid_ops.append(_run_with_mock( + classifier_metrics.frechet_classifier_distance, + test_pool_reals[i], + test_pool_gens[i], + classifier_fn=lambda x: x)) + + fids = [] + with self.test_session() as sess: + for fid_op in fid_ops: + fids.append(sess.run(fid_op)) + + # Check that the FIDs increase monotonically. + self.assertTrue(all(fid_a < fid_b for fid_a, fid_b in zip(fids, fids[1:]))) + def test_trace_sqrt_product_value(self): """Test that `trace_sqrt_product` gives the correct value.""" np.random.seed(0) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index 7956db43348c0cc0f3d372e92a2e343f5aa62013..45eb108586bed07434ac29595164745eac6054c1 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -90,8 +90,7 @@ class SummariesTest(test.TestCase): self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False) def test_add_gan_model_image_summaries_for_cyclegan(self): - self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, - True) + self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, True) def _test_add_gan_model_summaries_impl(self, get_model_fn, expected_num_summary_ops): diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 23a3b60cc0055917bfc5243b0ebdbaea7b61edb9..39588b7219ebac1cc4855532be3fcc38e6381134 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -305,6 +305,7 @@ def wasserstein_gradient_penalty( discriminator_fn, discriminator_scope, epsilon=1e-10, + target=1.0, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -324,6 +325,8 @@ def wasserstein_gradient_penalty( discriminator_scope: If not `None`, reuse discriminators from this scope. epsilon: A small positive number added for numerical stability when computing the gradient norm. + target: Optional Python number or `Tensor` indicating the target value of + gradient norm. Defaults to 1.0. weights: Optional `Tensor` whose rank is either 0, or the same rank as `real_data` and `generated_data`, and must be broadcastable to them (i.e., all dimensions must be either `1`, or the same as the @@ -374,7 +377,7 @@ def wasserstein_gradient_penalty( # For numerical stability, add epsilon to the sum before taking the square # root. Note tf.norm does not add epsilon. slopes = math_ops.sqrt(gradient_squares + epsilon) - penalties = math_ops.square(slopes - 1.0) + penalties = math_ops.square(slopes / target - 1.0) penalty = losses.compute_weighted_loss( penalties, weights, scope=scope, loss_collection=loss_collection, reduction=reduction) diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 56ac45554da3633149a61155a416fa7cb6cff553..dbaa624ae9d6a5a5949db692e52c0c1deb18b8df 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -481,6 +481,29 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): }) self.assertAlmostEqual(self._expected_loss, loss, 5) + def test_loss_with_gradient_norm_target(self): + """Test loss value with non default gradient norm target.""" + generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + real_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + + loss = tfgan_losses.wasserstein_gradient_penalty( + generated_data, + real_data, + self._kwargs['generator_inputs'], + self._kwargs['discriminator_fn'], + self._kwargs['discriminator_scope'], + target=2.0) + + with self.test_session() as sess: + variables.global_variables_initializer().run() + loss = sess.run( + loss, + feed_dict={ + generated_data: self._generated_data_np, + real_data: self._real_data_np, + }) + self.assertAlmostEqual(1.0, loss, 5) + def test_reuses_scope(self): """Test that gradient penalty reuses discriminator scope.""" num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 5d0ac93aec7869bb1d9b8a174ba50d4bec2c2826..776eb11ecb1624544d24611d8fe6ca19768b8313 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -460,6 +460,7 @@ def gan_loss( # Auxiliary losses. gradient_penalty_weight=None, gradient_penalty_epsilon=1e-10, + gradient_penalty_target=1.0, mutual_information_penalty_weight=None, aux_cond_generator_weight=None, aux_cond_discriminator_weight=None, @@ -481,6 +482,9 @@ def gan_loss( small positive value used by the gradient penalty function for numerical stability. Note some applications will need to increase this value to avoid NaNs. + gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python + number or `Tensor` indicating the target value of gradient norm. See the + CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0. mutual_information_penalty_weight: If not `None`, must be a non-negative Python number or Tensor indicating how much to weight the mutual information penalty. See https://arxiv.org/abs/1606.03657 for more @@ -539,7 +543,10 @@ def gan_loss( # Add optional extra losses. if _use_aux_loss(gradient_penalty_weight): gp_loss = tfgan_losses.wasserstein_gradient_penalty( - model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries) + model, + epsilon=gradient_penalty_epsilon, + target=gradient_penalty_target, + add_summaries=add_summaries) dis_loss += gradient_penalty_weight * gp_loss if _use_aux_loss(mutual_information_penalty_weight): info_loss = tfgan_losses.mutual_information_penalty( diff --git a/tensorflow/contrib/image/kernels/bipartite_match_op.cc b/tensorflow/contrib/image/kernels/bipartite_match_op.cc index 7d207c388b159c4ad0f25032811e97b153fd50d6..726adb07775e3243fdc96a7f1a00dbb0304d3dd9 100644 --- a/tensorflow/contrib/image/kernels/bipartite_match_op.cc +++ b/tensorflow/contrib/image/kernels/bipartite_match_op.cc @@ -85,7 +85,7 @@ class BipartiteMatchOp : public OpKernel { context->allocate_output(1, TensorShape({num_input_columns}), &column_to_row_match_indices)); - typename TTypes::ConstTensor distance_mat = + TTypes::ConstTensor distance_mat = input_distance_mat.shaped( {num_input_rows, num_input_columns}); diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index bfdb69ad02caaa57827e0ae6b3c9fc0d0ed03754..b12f7be76907dc206667eb8ee0c750f3b8db57fc 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -90,49 +90,51 @@ class EstimatorTest(test.TestCase): def testEstimatorInitManualRegistration(self): with self._graph.as_default(): # We should be able to build an estimator for only the registered vars. - estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection) + estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1, + self.layer_collection) # Check that we throw an error if we try to build an estimator for vars # that were not manually registered. with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2, + estimator.FisherEstimator(lambda: 0.2, [self.weights, self.bias], 0.1, self.layer_collection) # Check that we throw an error if we don't include registered variables, # i.e. self.weights with self.assertRaises(ValueError): - estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection) + estimator.FisherEstimator(lambda: 0.2, [], 0.1, self.layer_collection) @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) def testVariableWrongNumberOfUses(self, mock_uses): with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection) + estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1, + self.layer_collection) def testInvalidEstimationMode(self): with self.assertRaises(ValueError): - estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection, - "not_a_real_mode") + estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1, + self.layer_collection, "not_a_real_mode") def testModeListCorrect(self): with self._graph.as_default(): - est = estimator.FisherEstimator([self.weights], 0.1, 0.2, + est = estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1, self.layer_collection) self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys()) def testAllModesBuild(self): for mode in _ALL_ESTIMATION_MODES: with self._graph.as_default(): - estimator.FisherEstimator([self.weights], 0.1, 0.2, + estimator.FisherEstimator(lambda: 0.2, [self.weights], 0.1, self.layer_collection, mode) def test_cov_update_thunks(self): """Ensures covariance update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: fisher_estimator = estimator.FisherEstimator( + damping_fn=lambda: 0.2, variables=[self.weights], layer_collection=self.layer_collection, - cov_ema_decay=0.0, - damping=0.0) + cov_ema_decay=0.0) # Construct an op that executes one covariance update per step. global_step = training_util.get_or_create_global_step() @@ -176,10 +178,10 @@ class EstimatorTest(test.TestCase): """Ensures inverse update ops run once per global_step.""" with self._graph.as_default(), self.test_session() as sess: fisher_estimator = estimator.FisherEstimator( + damping_fn=lambda: 0.2, variables=[self.weights], layer_collection=self.layer_collection, - cov_ema_decay=0.0, - damping=0.0) + cov_ema_decay=0.0) # Construct op that updates one inverse per global step. global_step = training_util.get_or_create_global_step() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index 82accd57f0c37d140238f1884fce956654d14227..fb4b3a241c1e9fd82e7bf630fd57295917048fbd 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import state_ops @@ -236,10 +237,10 @@ class NaiveDiagonalFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) -class FullyConnectedDiagonalFB(test.TestCase): +class FullyConnectedDiagonalFBTest(test.TestCase): def setUp(self): - super(FullyConnectedDiagonalFB, self).setUp() + super(FullyConnectedDiagonalFBTest, self).setUp() self.batch_size = 4 self.input_size = 6 @@ -375,6 +376,65 @@ class FullyConnectedDiagonalFB(test.TestCase): return multiply_result, multiply_inverse_result +class EmbeddingKFACFBTest(test.TestCase): + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + + # Create a Fisher Block. + vocab_size = 5 + block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) + + # Add some examples. + inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) + outputs = array_ops.constant([[0.], [1.], [2.]]) + block.register_additional_minibatch(inputs, outputs) + + # Instantiate factor's variables. Ensure it doesn't fail. + grads = outputs**2. + damping = array_ops.constant(0.) + block.instantiate_factors(([grads],), damping) + + def testMultiplyInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + # Create a Fisher Block. + vocab_size = 5 + block = fb.EmbeddingKFACFB(lc.LayerCollection(), vocab_size) + + # Add some examples. + inputs = array_ops.constant([[0, 1], [1, 2], [2, 3]]) + outputs = array_ops.constant([[0.], [1.], [2.]]) + block.register_additional_minibatch(inputs, outputs) + + # Instantiate factor's variables. Ensure it doesn't fail. + grads = outputs**2. + damping = array_ops.constant(0.) + block.instantiate_factors(([grads],), damping) + + # Create a sparse update. + indices = array_ops.constant([1, 3, 4]) + values = array_ops.constant([[1.], [1.], [1.]]) + sparse_vector = ops.IndexedSlices( + values, indices, dense_shape=[vocab_size, 1]) + dense_vector = array_ops.reshape([0., 1., 0., 1., 1.], [vocab_size, 1]) + + # Compare Fisher-vector product against explicit result. + result = block.multiply_inverse(sparse_vector) + expected_result = linalg_ops.matrix_solve(block.full_fisher_block(), + dense_vector) + + sess.run(tf_variables.global_variables_initializer()) + self.assertAlmostEqual( + sess.run(expected_result[1]), sess.run(result.values[0])) + self.assertAlmostEqual( + sess.run(expected_result[3]), sess.run(result.values[1])) + self.assertAlmostEqual( + sess.run(expected_result[4]), sess.run(result.values[2])) + + class FullyConnectedKFACBasicFBTest(test.TestCase): def testFullyConnectedKFACBasicFBInit(self): diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 753378d9f4a0d8762bafbee2ec27d6c71783dda1..66e18974abfadaad5d7a20b40d0b1352bfda67ee 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -89,6 +89,21 @@ class FisherFactorTestingDummy(ff.FisherFactor): def make_inverse_update_ops(self): return [] + def get_cov(self): + return NotImplementedError + + def left_multiply(self, x, damping): + return NotImplementedError + + def right_multiply(self, x, damping): + return NotImplementedError + + def left_multiply_inverse(self, x, damping): + return NotImplementedError + + def right_multiply_inverse(self, x, damping): + return NotImplementedError + class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. @@ -379,7 +394,7 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) - self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + self.assertEqual([6, 1], factor.get_cov_var().get_shape().as_list()) def testNaiveDiagonalFactorInitFloat64(self): with tf_ops.Graph().as_default(): @@ -387,7 +402,7 @@ class NaiveDiagonalFactorTest(test.TestCase): random_seed.set_random_seed(200) tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.NaiveDiagonalFactor((tensor,), 32) - cov = factor.get_cov() + cov = factor.get_cov_var() self.assertEqual(cov.dtype, dtype) self.assertEqual([6, 1], cov.get_shape().as_list()) @@ -402,6 +417,29 @@ class NaiveDiagonalFactorTest(test.TestCase): self.assertAllClose([[0.75], [1.5]], new_cov) +class EmbeddingInputKroneckerFactorTest(test.TestCase): + + def testInitialization(self): + with tf_ops.Graph().as_default(): + input_ids = array_ops.constant([[0], [1], [4]]) + vocab_size = 5 + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + cov = factor.get_cov_var() + self.assertEqual(cov.shape.as_list(), [vocab_size]) + + def testCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(): + input_ids = array_ops.constant([[0], [1], [4]]) + vocab_size = 5 + factor = ff.EmbeddingInputKroneckerFactor((input_ids,), vocab_size) + cov_update_op = factor.make_covariance_update_op(0.0) + + with self.test_session() as sess: + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(cov_update_op) + self.assertAllClose(np.array([1., 1., 0., 0., 1.]) / 3., new_cov) + + class FullyConnectedKroneckerFactorTest(test.TestCase): def _testFullyConnectedKroneckerFactorInit(self, diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index ee6549b109399766579b6ea18a987ae2c8275983..c26230c2a82ae9529ab13b523b9ec287d17debaf 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -144,10 +144,13 @@ py_library( ":fisher_estimator", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", "//tensorflow/python:training", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index a7b1f9d35c931fc44408be804479e758f28f7110..a7e268c48ae326a4d8fa5fe4a4ed15b8b83a0ed9 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -83,9 +83,9 @@ class FisherEstimator(object): """ def __init__(self, + damping_fn, variables, cov_ema_decay, - damping, layer_collection, estimation_mode="gradients", colocate_gradients_with_ops=True, @@ -94,16 +94,12 @@ class FisherEstimator(object): """Create a FisherEstimator object. Args: + damping_fn: Function, accepts no arguments and returns damping value. variables: A list of the variables for which to estimate the Fisher. This must match the variables registered in layer_collection (if it is not None). cov_ema_decay: The decay factor used when calculating the covariance estimate moving averages. - damping: The damping factor used to stabilize training due to errors in - the local approximation with the Fisher information matrix, and to - regularize the update direction by making it closer to the gradient. - (Higher damping means the update looks more like a standard gradient - update - see Tikhonov regularization.) layer_collection: The layer collection object, which holds the fisher blocks, kronecker factors, and losses associated with the graph. @@ -135,10 +131,9 @@ class FisherEstimator(object): Raises: ValueError: If no losses have been registered with layer_collection. """ - + self._damping_fn = damping_fn self._cov_ema_decay = cov_ema_decay self._variables = variables - self._damping = damping self._estimation_mode = estimation_mode self._layers = layer_collection self._layers.create_subgraph() @@ -182,7 +177,7 @@ class FisherEstimator(object): @property def damping(self): - return self._damping + return self._damping_fn() def _apply_transformation(self, vecs_and_vars, transform): """Applies an block-wise transformation to the corresponding vectors. diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index 0d2fa706f5853570bb8c04a9b9ac3378e2f2386e..cf38d28b43836dced8babe2ffa7853b1c4b1b369 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -92,10 +92,22 @@ def compute_pi_tracenorm(left_cov, right_cov): Returns: The computed scalar constant pi for these Kronecker Factors (as a Tensor). """ + + def _trace(cov): + if len(cov.shape) == 1: + # Diagonal matrix. + return math_ops.reduce_sum(cov) + elif len(cov.shape) == 2: + # Full matrix. + return math_ops.trace(cov) + else: + raise ValueError( + "What's the trace of a Tensor of rank %d?" % len(cov.shape)) + # Instead of dividing by the dim of the norm, we multiply by the dim of the # other norm. This works out the same in the ratio. - left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0] - right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0] + left_norm = _trace(left_cov) * right_cov.shape.as_list()[0] + right_norm = _trace(right_cov) * left_cov.shape.as_list()[0] return math_ops.sqrt(left_norm / right_norm) @@ -201,15 +213,15 @@ class FullFB(FisherBlock): self._factor.register_damped_inverse(damping) def multiply_inverse(self, vector): - inverse = self._factor.get_damped_inverse(self._damping) - out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector)) + vector_flat = utils.tensors_to_column(vector) + out_flat = self._factor.left_multiply_inverse( + vector_flat, self._damping) return utils.column_to_tensors(vector, out_flat) def multiply(self, vector): vector_flat = utils.tensors_to_column(vector) - out_flat = ( - math_ops.matmul(self._factor.get_cov(), vector_flat) + - self._damping * vector_flat) + out_flat = self._factor.left_multiply( + vector_flat, self._damping) return utils.column_to_tensors(vector, out_flat) def full_fisher_block(self): @@ -265,16 +277,20 @@ class NaiveDiagonalFB(FisherBlock): def multiply_inverse(self, vector): vector_flat = utils.tensors_to_column(vector) - out_flat = vector_flat / (self._factor.get_cov() + self._damping) + print("vector_flat: %s" % vector_flat) + out_flat = self._factor.left_multiply_inverse( + vector_flat, self._damping) + print("out_flat: %s" % out_flat) return utils.column_to_tensors(vector, out_flat) def multiply(self, vector): vector_flat = utils.tensors_to_column(vector) - out_flat = vector_flat * (self._factor.get_cov() + self._damping) + out_flat = self._factor.left_multiply( + vector_flat, self._damping) return utils.column_to_tensors(vector, out_flat) def full_fisher_block(self): - return array_ops.diag(array_ops.reshape(self._factor.get_cov(), (-1,))) + return self._factor.get_cov() def tensors_to_compute_grads(self): return self._params @@ -356,8 +372,9 @@ class FullyConnectedDiagonalFB(FisherBlock): Tensor of the same shape, corresponding to the inverse Fisher-vector product. """ - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) + reshaped_vec = utils.layer_params_to_mat2d(vector) + reshaped_out = self._factor.left_multiply_inverse( + reshaped_vec, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def multiply(self, vector): @@ -372,8 +389,9 @@ class FullyConnectedDiagonalFB(FisherBlock): Returns: Tensor of the same shape, corresponding to the Fisher-vector product. """ - reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) + reshaped_vec = utils.layer_params_to_mat2d(vector) + reshaped_out = self._factor.left_multiply( + reshaped_vec, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def tensors_to_compute_grads(self): @@ -468,12 +486,14 @@ class ConvDiagonalFB(FisherBlock): def multiply_inverse(self, vector): reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) + reshaped_out = self._factor.left_multiply_inverse( + reshaped_vect, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def multiply(self, vector): reshaped_vect = utils.layer_params_to_mat2d(vector) - reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) + reshaped_out = self._factor.left_multiply( + reshaped_vect, self._damping) return utils.mat2d_to_layer_params(vector, reshaped_out) def tensors_to_compute_grads(self): @@ -533,28 +553,24 @@ class KroneckerProductFB(FisherBlock): return 1.0 def multiply_inverse(self, vector): - left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping) - right_factor_inv = self._output_factor.get_damped_inverse( - self._output_damping) reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = math_ops.matmul(left_factor_inv, - math_ops.matmul(reshaped_vector, - right_factor_inv)) + reshaped_out = self._output_factor.right_multiply_inverse( + reshaped_vector, + self._output_damping) + reshaped_out = self._input_factor.left_multiply_inverse( + reshaped_out, self._input_damping) if self._renorm_coeff != 1.0: reshaped_out /= math_ops.cast( self._renorm_coeff, dtype=reshaped_out.dtype) return utils.mat2d_to_layer_params(vector, reshaped_out) def multiply(self, vector): - left_factor = self._input_factor.get_cov() - right_factor = self._output_factor.get_cov() reshaped_vector = utils.layer_params_to_mat2d(vector) - reshaped_out = ( - math_ops.matmul(reshaped_vector, right_factor) + - self._output_damping * reshaped_vector) - reshaped_out = ( - math_ops.matmul(left_factor, reshaped_out) + - self._input_damping * reshaped_out) + reshaped_out = self._output_factor.right_multiply( + reshaped_vector, + self._output_damping) + reshaped_out = self._input_factor.left_multiply( + reshaped_out, self._input_damping) if self._renorm_coeff != 1.0: reshaped_out *= math_ops.cast( self._renorm_coeff, dtype=reshaped_out.dtype) @@ -574,6 +590,74 @@ class KroneckerProductFB(FisherBlock): right_factor) +class EmbeddingKFACFB(KroneckerProductFB): + """K-FAC FisherBlock for embedding layers. + + This FisherBlock is similar to EmbeddingKFACFB, except that its + input factor is approximated by a diagonal matrix. In the case that each + example references exactly one embedding, this approximation is exact. + + Does not support bias parameters. + """ + + def __init__(self, layer_collection, vocab_size): + """Creates a EmbeddingKFACFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + vocab_size: int. Size of vocabulary for this embedding layer. + """ + self._inputs = [] + self._outputs = [] + self._vocab_size = vocab_size + + super(EmbeddingKFACFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of Tensors. grads_list[i][j] is the + gradient of the loss with respect to 'outputs' from source 'i' and + tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + # TODO(b/68033310): Validate which of, + # (1) summing on a single device (as below), or + # (2) on each device in isolation and aggregating + # is faster. + inputs = _concat_along_batch_dim(self._inputs) + grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( # + fisher_factors.EmbeddingInputKroneckerFactor, # + ((inputs,), self._vocab_size)) + self._output_factor = self._layer_collection.make_or_get_factor( # + fisher_factors.FullyConnectedKroneckerFactor, # + (grads_list,)) + self._register_damped_input_and_output_inverses(damping) + + def tensors_to_compute_grads(self): + return self._outputs + + def register_additional_minibatch(self, inputs, outputs): + """Registers an additional minibatch to the FisherBlock. + + Args: + inputs: Tensor of shape [batch_size, input_size]. Inputs to the + matrix-multiply. + outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. + """ + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_minibatches(self): + return len(self._inputs) + + class FullyConnectedKFACBasicFB(KroneckerProductFB): """K-FAC FisherBlock for fully-connected (dense) layers. diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py index ac396309206fe09af65c2b70840a513fb25b579b..c04cf727fa958160d61c7a3638ec65f6c93c2f24 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -29,6 +29,7 @@ _allowed_symbols = [ 'NaiveDiagonalFB', 'FullyConnectedDiagonalFB', 'KroneckerProductFB', + 'EmbeddingKFACFB', 'FullyConnectedKFACBasicFB', 'ConvKFCBasicFB', 'ConvDiagonalFB', @@ -36,7 +37,9 @@ _allowed_symbols = [ 'compute_pi_tracenorm', 'compute_pi_adjusted_damping', 'num_conv_locations', - 'normalize_damping' + 'normalize_damping', + 'LEFT_MULTIPLY', + 'RIGHT_MULTIPLY', ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index bcba18ae147c6ceca50bc9a2a17e01fc201d88c1..603d8b8b210279ee6d8f1de0ce10869fde23f4d9 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -25,13 +25,13 @@ import numpy as np import six from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -112,54 +112,6 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di return array_ops.ones(shape, dtype) -def extract_image_patches(image, ksizes, strides, padding, name=None): - """Extracts image patches for an N-dimensional convolution. - - This function is a compatibility wrapper over tf.extract_image_patches(), as - ExtractImagePatches isn't yet implemented in XLA. - - Args: - image: Tensor of shape [batch, in_x, in_y, ..., in_channels]. Input images. - All dimensions except 'batch' must be defined. - ksizes: [filter_x, filter_y, ...]. Spatial shape of filter in each - dimension. - strides: [stride_x, stride_y, ...]. Spatial stride for filter in each - dimension. - padding: str. "VALID" or "SAME". - name: str or None. name of Op. - - Returns: - result: [batch, out_x, out_y, ..., filter_x, filter_y, ..., in_channels]. - Contains image patches to which conv kernel would be applied for each - output location. [out_x, out_y, ...] depends on padding. - """ - if not utils.on_tpu(): - return array_ops.extract_image_patches( - image, - ksizes=([1] + list(ksizes) + [1]), - strides=([1] + list(strides) + [1]), - rates=[1, 1, 1, 1], - padding=padding, - name=name) - - with tf_ops.name_scope(name, "extract_image_patches", - [image, ksizes, strides, padding]): - batch = image.shape.as_list()[0] - in_channels = image.shape.as_list()[-1] - - # Map each input feature to a location in the output. - out_channels = np.prod(ksizes) * in_channels - filters = linalg_ops.eye(out_channels), - filters = array_ops.reshape(filters, ksizes + [in_channels, out_channels]) - - result = nn.convolution(image, filters, padding, strides=strides) - out_spatial = result.shape.as_list()[1:-1] - result = array_ops.reshape( - result, [batch or -1] + out_spatial + ksizes + [in_channels]) - - return result - - def compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. @@ -259,12 +211,21 @@ def scalar_or_tensor_to_string(val): class FisherFactor(object): """Base class for objects modeling factors of approximate Fisher blocks. - Note that for blocks that aren't based on approximations, a 'factor' can - be the entire block itself, as is the case for the diagonal and full - representations. + A FisherFactor represents part of an approximate Fisher Information matrix. + For example, one approximation to the Fisher uses the Kronecker product of two + FisherFactors A and B, F = kron(A, B). FisherFactors are composed with + FisherBlocks to construct a block-diagonal approximation to the full Fisher. + + FisherFactors are backed by a single, non-trainable variable that is updated + by running FisherFactor.make_covariance_update_op(). The shape and type of + this variable is implementation specific. - Subclasses must implement the _compute_new_cov method, and the _var_scope - and _cov_shape properties. + Note that for blocks that aren't based on approximations, a 'factor' can + be the entire block itself, as is the case for the diagonal and full + representations. + + Subclasses must implement the _compute_new_cov() method, and the _var_scope + and _cov_shape properties. """ def __init__(self): @@ -272,16 +233,21 @@ class FisherFactor(object): @abc.abstractproperty def _var_scope(self): + """Variable scope for this FisherFactor instance. + + Returns: + string that unique identifies this FisherFactor instance. + """ pass @abc.abstractproperty def _cov_shape(self): - """The shape of the cov matrix.""" + """The shape of the variable backing this FisherFactor.""" pass @abc.abstractproperty def _num_sources(self): - """The number of things to sum over when computing cov. + """The number of things to sum over when updating covariance variable. The default make_covariance_update_op function will call _compute_new_cov with indices ranging from 0 to _num_sources-1. The typical situation is @@ -293,10 +259,12 @@ class FisherFactor(object): @abc.abstractproperty def _dtype(self): + """dtype for variable backing this factor.""" pass @property def _cov_initializer(self): + """Function for initializing covariance variable.""" return covariance_initializer def instantiate_covariance(self): @@ -311,6 +279,15 @@ class FisherFactor(object): @abc.abstractmethod def _compute_new_cov(self, idx=0): + """Computes minibatch-estimated covariance for a single source. + + Args: + idx: int in [0, self._num_sources). Which source to use when estimating + covariance. + + Returns: + Tensor of same shape as self.get_cov_var(). + """ pass def make_covariance_update_op(self, ema_decay): @@ -343,14 +320,101 @@ class FisherFactor(object): """Create and return update ops corresponding to registered computations.""" pass + @abc.abstractmethod def get_cov(self): + """Get full covariance matrix. + + Returns: + Tensor of shape [n, n]. Represents all parameter-parameter correlations + captured by this FisherFactor. + """ + pass + + def get_cov_var(self): + """Get variable backing this FisherFactor. + + May or may not be the same as self.get_cov() + + Returns: + Variable of shape self._cov_shape. + """ return self._cov + @abc.abstractmethod + def left_multiply(self, x, damping): + """Multiplies 'x' by the damped covariance of this factor. + + Let C be the covariance matrix this factor represents, and + D = C + damping * I be its damped variant. This method calculates + matmul(D, vec(x)). + + Args: + x: Tensor. Represents a single vector. Shape depends on implementation. + damping: 0-D Tensor. Damping to add to C's diagonal. + + Returns: + Tensor of same shape as 'x'. + """ + pass + + @abc.abstractmethod + def right_multiply(self, x, damping): + """Multiplies 'x' by the damped covariance of this factor. + + Let C be the covariance matrix this factor represents, and + D = C + damping * I be its damped variant. This method calculates + matmul(vec(x), D). + + Args: + x: Tensor. Represents a single vector. Shape depends on implementation. + damping: 0-D Tensor. Damping to add to C's diagonal. + + Returns: + Tensor of same shape as 'x'. + """ + pass + + @abc.abstractmethod + def left_multiply_inverse(self, x, damping): + """Multiplies 'x' by damped inverse of this factor. + + Let C be the covariance matrix this factor represents and + E = inv(C + damping * I) be its damped inverse. This method calculates + matmul(E, vec(x)). + + Args: + x: Tensor. Represents a single vector. Shape depends on implementation. + damping: 0-D Tensor. Damping to add to C's diagonal. + + Returns: + Tensor of same shape as 'x'. + """ + pass + + @abc.abstractmethod + def right_multiply_inverse(self, x, damping): + """Multiplies 'x' by damped inverse of this factor. + + Let C be the covariance matrix this factor represents and + E = inv(C + damping * I) be its damped inverse. This method calculates + matmul(vec(x), E). + + Args: + x: Tensor. Represents a single vector. Shape depends on implementation. + damping: 0-D Tensor. Damping to add to C's diagonal. + + Returns: + Tensor of same shape as 'x'. + """ + pass + class InverseProvidingFactor(FisherFactor): - """Base class for FisherFactors that maintain inverses, powers, etc of _cov. + """Base class for FisherFactors that maintain inverses explicitly. - Assumes that the _cov property is a square PSD matrix. + This class explicitly calculates and stores inverses of covariance matrices + provided by the underlying FisherFactor implementation. It is assumed that + vectors can be represented as 2-D matrices. Subclasses must implement the _compute_new_cov method, and the _var_scope and _cov_shape properties. @@ -485,6 +549,61 @@ class InverseProvidingFactor(FisherFactor): def reset_eigendecomp(self): self._eigendecomp = None + def get_cov(self): + # Variable contains full covariance matrix. + return self.get_cov_var() + + def left_multiply(self, x, damping): + n = self.get_cov().shape[0] + damped_cov = self.get_cov() + damping * array_ops.eye(n) + + if isinstance(x, tf_ops.IndexedSlices): + raise NotImplementedError( + "Left-multiply not yet supported for IndexedSlices.") + + if len(x.shape) != 2: + raise ValueError( + "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." + % (x,)) + + return math_ops.matmul(damped_cov, x) + + def right_multiply(self, x, damping): + n = self.get_cov().shape[0] + damped_cov = self.get_cov() + damping * array_ops.eye(n) + + if isinstance(x, tf_ops.IndexedSlices): + return utils.matmul_sparse_dense(x, damped_cov) + + if len(x.shape) != 2: + raise ValueError( + "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." + % (x,)) + + return math_ops.matmul(x, damped_cov) + + def left_multiply_inverse(self, x, damping): + if isinstance(x, tf_ops.IndexedSlices): + raise ValueError("Left-multiply not yet supported for IndexedSlices.") + + if x.shape.ndims != 2: + raise ValueError( + "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." + % (x,)) + + return math_ops.matmul(self.get_damped_inverse(damping), x) + + def right_multiply_inverse(self, x, damping): + if isinstance(x, tf_ops.IndexedSlices): + return utils.matmul_sparse_dense(x, self.get_damped_inverse(damping)) + + if x.shape.ndims != 2: + raise ValueError( + "InverseProvidingFactors apply to matrix-shaped vectors. Found: %s." + % (x,)) + + return math_ops.matmul(x, self.get_damped_inverse(damping)) + class FullFactor(InverseProvidingFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. @@ -530,7 +649,11 @@ class FullFactor(InverseProvidingFactor): class DiagonalFactor(FisherFactor): - """A base class for FisherFactors that use diagonal approximations.""" + """A base class for FisherFactors that use diagonal approximations. + + A DiagonalFactor's covariance variable can be of any shape, but must contain + exactly one entry per parameter. + """ def __init__(self): super(DiagonalFactor, self).__init__() @@ -542,6 +665,45 @@ class DiagonalFactor(FisherFactor): def make_inverse_update_ops(self): return [] + def get_cov(self): + # self.get_cov() could be any shape, but it must have one entry per + # parameter. Flatten it into a vector. + cov_diag_vec = array_ops.reshape(self.get_cov_var(), [-1]) + return array_ops.diag(cov_diag_vec) + + def left_multiply(self, x, damping): + damped_cov = self.get_cov_var() + damping + if isinstance(x, tf_ops.IndexedSlices): + return utils.matmul_diag_sparse(array_ops.reshape(damped_cov, [-1]), x) + + if x.shape != damped_cov.shape: + raise ValueError("x (%s) and cov (%s) must have same shape." % + (x, damped_cov)) + + return damped_cov * x + + def right_multiply(self, x, damping): + raise NotImplementedError("Only left-multiply is currently supported.") + + def left_multiply_inverse(self, x, damping): + inverse = 1. / (self.get_cov_var() + damping) + + if isinstance(x, tf_ops.IndexedSlices): + return utils.matmul_diag_sparse(array_ops.reshape(inverse, [-1]), x) + + if x.shape != inverse.shape: + raise ValueError("x (%s) and cov (%s) must have same shape." % + (x, inverse)) + + return inverse * x + + def right_multiply_inverse(self, x, damping): + raise NotImplementedError("Only left-multiply is currently supported.") + + def register_damped_inverse(self, damping): + # DiagonalFactors don't keep explicit inverses. + pass + class NaiveDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approximation of any type of param's Fisher. @@ -553,6 +715,14 @@ class NaiveDiagonalFactor(DiagonalFactor): def __init__(self, params_grads, batch_size): + """Initializes NaiveDiagonalFactor instance. + + Args: + params_grads: Sequence of Tensors, each with same shape as parameters this + FisherFactor corresponds to. For example, the gradient of the loss with + respect to parameters. + batch_size: int or 0-D Tensor. Size + """ self._params_grads = tuple(utils.ensure_sequence(params_grad) for params_grad in params_grads) self._batch_size = batch_size @@ -567,7 +737,7 @@ class NaiveDiagonalFactor(DiagonalFactor): def _cov_shape(self): size = sum(param_grad.shape.num_elements() for param_grad in self._params_grads[0]) - return (size, 1) + return [size, 1] @property def _num_sources(self): @@ -584,6 +754,84 @@ class NaiveDiagonalFactor(DiagonalFactor): self._batch_size, params_grads_flat.dtype)) +class EmbeddingInputKroneckerFactor(DiagonalFactor): + r"""FisherFactor for input to an embedding layer. + + Given input_ids = [batch_size, input_size] representing indices into an + [vocab_size, embedding_size] embedding matrix, approximate input covariance by + a diagonal matrix, + + Cov(input_ids, input_ids) = + (1/batch_size) sum_{i} diag(n_hot(input[i]) ** 2). + + where n_hot() constructs an n-hot binary vector and diag() constructs a + diagonal matrix of size [vocab_size, vocab_size]. + """ + + def __init__(self, input_ids, vocab_size, dtype=None): + """Instantiate EmbeddingInputKroneckerFactor. + + Args: + input_ids: Tuple of Tensors of shape [batch_size, input_size] and dtype + int32. Indices into embedding matrix. + vocab_size: int or 0-D Tensor. Maximum value for entries in 'input_ids'. + dtype: dtype for covariance statistics. Must be a floating point type. + Defaults to float32. + """ + self._input_ids = input_ids + self._vocab_size = vocab_size + self._cov_dtype = dtype or dtypes.float32 + + super(EmbeddingInputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_diag_embedding/" + scope_string_from_params(self._input_ids) + + @property + def _cov_shape(self): + return [self._vocab_size] + + @property + def _num_sources(self): + return len(self._input_ids) + + @property + def _dtype(self): + return self._cov_dtype + + def _compute_new_cov(self, idx=0): + with maybe_colocate_with(self._input_ids): + input_ids = self._input_ids[idx] + if len(input_ids.shape) > 2: + raise ValueError( + "Input to embeddings must have rank <= 2. Found rank %d." % len( + input_ids.shape)) + + batch_size = array_ops.shape(input_ids)[0] + + # Transform indices into one-hot vectors. + # + # TODO(b/72714822): There must be a faster way to construct the diagonal + # covariance matrix! This operation is O(batch_size * vocab_size), where + # it should be O(batch_size * input_size). + flat_input_ids = array_ops.reshape(input_ids, [-1]) + one_hots = array_ops.one_hot(flat_input_ids, + self._vocab_size) # [?, vocab_size] + + # Take average across examples. Note that, because all entries have + # magnitude zero or one, there's no need to square the entries. + # + # TODO(b/72714822): Support for SparseTensor, other kinds of aggregation + # within an example such as average. + # + # TODO(b/72714822): Support for partitioned embeddings. + new_cov = math_ops.reduce_sum(one_hots, axis=0) # [vocab_size] + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + + return new_cov + + class FullyConnectedDiagonalFactor(DiagonalFactor): r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher. @@ -623,8 +871,9 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): @property def _cov_shape(self): - return [self._inputs.shape[1] + self._has_bias, - self._outputs_grads[0].shape[1]] + input_size = self._inputs.shape[1] + self._has_bias + output_size = self._outputs_grads[0].shape[1] + return [input_size, output_size] @property def _num_sources(self): @@ -717,10 +966,11 @@ class ConvDiagonalFactor(DiagonalFactor): # TODO(b/64144716): there is potential here for a big savings in terms # of memory use. - patches = extract_image_patches( + patches = array_ops.extract_image_patches( self._inputs, - ksizes=[filter_height, filter_width], - strides=self._strides[1:-1], + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=[1, 1, 1, 1], padding=self._padding) if self._has_bias: @@ -864,10 +1114,11 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): # TODO(b/64144716): there is potential here for a big savings in terms of # memory use. - patches = extract_image_patches( + patches = array_ops.extract_image_patches( self._inputs, - ksizes=[filter_height, filter_width], - strides=self._strides[1:-1], + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=[1, 1, 1, 1], padding=self._padding) flatten_size = (filter_height * filter_width * in_channels) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py index ad93919149c287b1932dd2b6bd772c0dab26192d..2d8e378a932c16d48360bc4b15ff4f3239c0ed1f 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -24,26 +24,15 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ - "inverse_initializer", - "covariance_initializer", - "diagonal_covariance_initializer", - "scope_string_from_params", - "scope_string_from_name", - "scalar_or_tensor_to_string", - "FisherFactor", - "InverseProvidingFactor", - "FullFactor", - "DiagonalFactor", - "NaiveDiagonalFactor", - "FullyConnectedDiagonalFactor", - "FullyConnectedKroneckerFactor", - "ConvInputKroneckerFactor", - "ConvOutputKroneckerFactor", - "ConvDiagonalFactor", - "set_global_constants", - "maybe_colocate_with", - "compute_cov", - "append_homog" + "inverse_initializer", "covariance_initializer", + "diagonal_covariance_initializer", "scope_string_from_params", + "scope_string_from_name", "scalar_or_tensor_to_string", "FisherFactor", + "InverseProvidingFactor", "FullFactor", "DiagonalFactor", + "NaiveDiagonalFactor", "EmbeddingInputKroneckerFactor", + "FullyConnectedDiagonalFactor", "FullyConnectedKroneckerFactor", + "ConvInputKroneckerFactor", "ConvOutputKroneckerFactor", + "ConvDiagonalFactor", "set_global_constants", "maybe_colocate_with", + "compute_cov", "append_homog" ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 8d450f04f379701e46a18b2e34bbbd6fcfcce2bb..ce9005b9ce99a4efa5f2821c56e199dd2086482e 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -143,6 +143,7 @@ class LayerCollection(object): self._loss_dict = {} # {str: LossFunction} self._subgraph = None self._default_generic_approximation = APPROX_FULL_NAME + self._default_embedding_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME self._default_fully_connected_multi_approximation = ( @@ -178,6 +179,17 @@ class LayerCollection(object): """ return self._linked_parameters + @property + def default_embedding_approximation(self): + return self._default_embedding_approximation + + def set_default_embedding_approximation(self, value): + if value != APPROX_KRONECKER_NAME: + raise ValueError( + "{} is not a valid approximation for embedding variables.".format( + value)) + self._default_embedding_approximation = value + @property def default_generic_approximation(self): return self._default_generic_approximation @@ -417,6 +429,46 @@ class LayerCollection(object): else: return None + def register_embedding(self, + params, + inputs, + outputs, + approx=None, + reuse=VARIABLE_SCOPE): + """Registers a fully connnected layer. + + Args: + params: Embedding matrix of shape [vocab_size, embedding_size]. + inputs: Tensor of shape [batch_size, input_size] and dtype int32. Indices + into embedding matrix. + outputs: Tensor of shape [batch_size, output_size]. Outputs + produced by layer. + approx: str. Must be "kron". + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If "VARIABLE_SCOPE", use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_embedding_approximation + + if approx != APPROX_KRONECKER_NAME: + raise ValueError("Bad value {} for approx.".format(approx)) + + if isinstance(params, (tuple, list)): + raise ValueError("Bias not supported.") + + vocab_size = int(params.shape[0]) + block = self.register_block( + params, fb.EmbeddingKFACFB(self, vocab_size), reuse=reuse) + block.register_additional_minibatch(inputs, outputs) + def register_fully_connected(self, params, inputs, diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index 1974b07acfc879dc4bc844db9af88fd1043d6698..5d456bcb79ff00cedc1aaa7244cc8722d21f6e98 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -23,11 +23,14 @@ from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products from tensorflow.contrib.kfac.python.ops import estimator as est # pylint enable=long-line +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as tf_variables from tensorflow.python.training import gradient_descent @@ -61,6 +64,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): damping: The damping factor used to stabilize training due to errors in the local approximation with the Fisher information matrix, and to regularize the update direction by making it closer to the gradient. + If damping is adapted during training then this value is used for + initializing damping varaible. (Higher damping means the update looks more like a standard gradient update - see Tikhonov regularization.) layer_collection: The layer collection object, which holds the fisher @@ -105,10 +110,31 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): if variables is None: variables = tf_variables.trainable_variables() + # The below paramaters are required only if damping needs to be adapated. + # These parameters can be set by calling + # set_damping_adaptation_params() explicitly. + self._damping_adaptation_decay = 0.95 + self._damping_adaptation_interval = 5 + # Check section 6.5 KFAC paper. omega(1) = pow(damping decay, interval) + self._omega = ( + self._damping_adaptation_decay**self._damping_adaptation_interval) + self._adapt_damping = False + self._min_damping = 1e-5 + self._prev_train_batch = None + self._is_chief = False + self._loss_fn = None + self._damping_constant = damping + self._damping = None + self._rho = None + self._prev_loss = None + self._q_model_change = None + self._update_damping_op = None + + self._layers = layer_collection self._fisher_est = est.FisherEstimator( + lambda: self.damping, variables, cov_ema_decay, - damping, layer_collection, estimation_mode=estimation_mode, colocate_gradients_with_ops=colocate_gradients_with_ops, @@ -139,6 +165,60 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): super(KfacOptimizer, self).__init__(learning_rate, name=name) + def set_damping_adaptation_params(self, + is_chief, + prev_train_batch, + loss_fn, + min_damping=1e-5, + damping_adaptation_decay=0.99, + damping_adaptation_interval=5): + """Sets parameters required to adapt damping during training. + + When called, enables damping adaptation according to the Levenberg-Marquardt + style rule described in Section 6.5 of "Optimizing Neural Networks with + Kronecker-factored Approximate Curvature". + + Args: + is_chief: `Boolean`, `True` if the worker is chief. + prev_train_batch: Training data used to minimize loss in the previous + step. This will be used to evaluate loss by calling + `loss_fn(prev_train_batch)`. + loss_fn: `function` that takes as input training data tensor and returns + a scalar loss. + min_damping: `float`(Optional), Minimum value the damping parameter + can take. Default value 1e-5. + damping_adaptation_decay: `float`(Optional), The `damping` parameter is + multipled by the `damping_adaptation_decay` every + `damping_adaptation_interval` number of iterations. Default value 0.99. + damping_adaptation_interval: `int`(Optional), Number of steps in between + updating the `damping` parameter. Default value 5. + + Raises: + ValueError: If `set_damping_adaptation_params` is already called and the + the `adapt_damping` is `True`. + """ + if self._adapt_damping: + raise ValueError("Damping adaptation parameters already set.") + with variable_scope.variable_scope(self.get_name()): + self._adapt_damping = True + self._is_chief = is_chief + self._prev_train_batch = prev_train_batch + self._loss_fn = loss_fn + self._damping_adaptation_decay = damping_adaptation_decay + self._damping_adaptation_interval = damping_adaptation_interval + self._omega = ( + self._damping_adaptation_decay**self._damping_adaptation_interval) + self._min_damping = min_damping + + self._rho = variable_scope.get_variable( + "rho", shape=(), dtype=dtypes.float32, trainable=False) # LM ratio. + self._prev_loss = variable_scope.get_variable( + "prev_loss", shape=(), dtype=dtypes.float32, trainable=False) + self._q_model_change = variable_scope.get_variable( + "q_model_change", shape=(), dtype=dtypes.float32, trainable=False) + self._damping = variable_scope.get_variable( + "damping", initializer=self._damping_constant, trainable=False) + @property def cov_update_thunks(self): return self._fisher_est.cov_update_thunks @@ -169,14 +249,34 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): @property def damping(self): - return self._fisher_est.damping + if self._damping: + return self._damping + else: + return self._damping_constant + + @property + def damping_adaptation_interval(self): + return self._damping_adaptation_interval def minimize(self, *args, **kwargs): kwargs["var_list"] = kwargs.get("var_list") or self.variables if set(kwargs["var_list"]) != set(self.variables): raise ValueError("var_list doesn't match with set of Fisher-estimating " "variables.") - return super(KfacOptimizer, self).minimize(*args, **kwargs) + if self._adapt_damping and self._is_chief: + global_step = kwargs.get("global_step", None) + if not global_step: + raise KeyError("global_step needs to be passed to optimizer.minimize " + "if damping parameter is adapted.") + update_damping_op = self._update_damping(self._prev_train_batch, + global_step) + with ops.control_dependencies([update_damping_op]): + loss = args[0] + loss_assign_op = state_ops.assign(self._prev_loss, loss) + train_op = super(KfacOptimizer, self).minimize(*args, **kwargs) + return control_flow_ops.group(loss_assign_op, train_op) + else: + return super(KfacOptimizer, self).minimize(*args, **kwargs) def compute_gradients(self, *args, **kwargs): # args[1] could be our var_list @@ -296,6 +396,20 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars) return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars] + def _compute_prev_updates(self, variables): + """Computes previous updates as negative velocities scaled by learning rate. + + Args: + variables: List of variables in the graph that the update will be + applied to. + + Returns: + List of previous updates applied to the `variables`. + """ + return list( + -1 * self._learning_rate * self._zeros_slot(var, "velocity", self._name) + for var in variables) + def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads, variables): """Compute optimal update hyperparameters from the quadratic model. @@ -374,9 +488,9 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)], [_inner_product_list(grads, prev_updates)]]) - sol = _two_by_two_solve(m, c) - alpha = -sol[0] - mu = -sol[1] + sol = -1. * _two_by_two_solve(m, c) + alpha = sol[0] + mu = sol[1] qmodel_change = 0.5 * math_ops.reduce_sum(sol * c) return alpha, mu, qmodel_change @@ -404,6 +518,52 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): return control_flow_ops.cond( math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case) + def _assign_q_model_change(self, q_model_change): + """Assigns `q_model_change` to `self._q_model_change` if damping is adapted. + + Note only the chief worker does the assignment. + + Args: + q_model_change: Scalar tensor of type `float32`. + + Returns: + If `adapt_damping` is `True` then returns an assign op, Otherwise returns + a no_op(). + """ + if self._adapt_damping and self._is_chief: + q_model_assign_op = state_ops.assign(self._q_model_change, q_model_change) + else: + q_model_assign_op = control_flow_ops.no_op() + return q_model_assign_op + + def _compute_qmodel_hyperparams_wrapper(self, grads_and_vars, + precon_grads_and_vars): + """Wrapper function for `self._compute_qmodel_hyperparams`. + + Constructs a list of preconditioned gradients and variables. Also creates a + op to asssign the computed q model change to `self._q_model_change`. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradients, variable) + pairs. + + Returns: + (alpha, mu, q_model_assign_op), where alpha and mu are chosen to optimize + the quadratic model, `q_model_assign_op` assigns the computed q model + change to `self._q_model_change`. + """ + precon_grads = list( + precon_grad for (precon_grad, _) in precon_grads_and_vars) + grads = list(grad for (grad, _) in grads_and_vars) + variables = list(var for (_, var) in grads_and_vars) + prev_updates = self._compute_prev_updates(variables) + # Compute optimal velocity update parameters according to quadratic model + alpha, mu, q_model_change = self._compute_qmodel_hyperparams( + precon_grads, prev_updates, grads, variables) + + return alpha, mu, self._assign_q_model_change(q_model_change) + def _compute_update_steps(self, grads_and_vars): """Computes the update steps for the variables given the gradients. @@ -411,8 +571,10 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): grads_and_vars: List of (gradient, variable) pairs. Returns: - An 'Operation that computes the update steps for the given variables. + A list of tuple (assign_op ,var) where `assign_op` assigns the update + steps to `var`. """ + if self._momentum_type == "regular": # Compute "preconditioned" gradient. precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) @@ -423,8 +585,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): precon_grads_and_vars) # Update the velocity with this and return it as the step. - return self._update_velocities(precon_grads_and_vars, self._momentum) - + if self._adapt_damping and self._is_chief: + _, _, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( + grads_and_vars, precon_grads_and_vars) + with ops.control_dependencies([q_model_assign_op]): + return self._update_velocities(precon_grads_and_vars, self._momentum) + else: + return self._update_velocities(precon_grads_and_vars, self._momentum) elif self._momentum_type == "adam": # Update velocity. velocities_and_vars = self._update_velocities(grads_and_vars, @@ -436,23 +603,13 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # Compute "preconditioned" gradient. precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) - # Extract out singleton lists from the tuple-lists - precon_grads = list( - precon_grad for (precon_grad, _) in precon_grads_and_vars) - grads = list(grad for (grad, _) in grads_and_vars) - variables = list(var for (_, var) in grads_and_vars) - # previous updates are the negative velocities (up to scaling by LR) - prev_updates = list( - -self._zeros_slot(var, "velocity", self._name) for var in variables) - # Compute optimal velocity update parameters according to quadratic model - alpha, mu, _ = self._compute_qmodel_hyperparams( - precon_grads, prev_updates, grads, variables) + alpha, mu, q_model_assign_op = self._compute_qmodel_hyperparams_wrapper( + grads_and_vars, precon_grads_and_vars) - # Update the velocity with precon_grads according to these params - # and return it as the step. - return self._update_velocities( - precon_grads_and_vars, mu, vec_coeff=-alpha) + with ops.control_dependencies([q_model_assign_op]): + return self._update_velocities( + precon_grads_and_vars, mu, vec_coeff=-alpha) def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0): """Updates the velocities of the variables with the given vectors. @@ -482,6 +639,51 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): # Go through variable and update its associated part of the velocity vector. return [_update_velocity(vec, var) for vec, var in vecs_and_vars] + # TODO(b/73448937): Move all update damping code to a separate class/function. + def _update_damping(self, prev_batch, global_step): + """Adapts damping parameter. Check KFAC (Section 6.5) for the details. + + The damping parameter is updated according to the Levenberg-Marquardt rule + every `self._damping_adaptation_interval` iterations. + + Args: + prev_batch: Tensor or tuple of tensors which can be passed to + `self._loss_fn` to evaluate loss. + global_step: `Variable` which keeps track of number of times the training + variables have been updated. + Returns: + A `tf.cond` op which updates the damping parameter. + """ + def compute_damping(): + """"Adapts damping parameter based on "reduction ratio". + + Reduction ratio captures how closely the quadratic approximation to the + loss function approximates the actual loss within a trust region. The + damping update tries to make the damping as small as possible while + maintaining the property that the quadratic model remains a good local + approximation to the loss function. + + Returns: + An Op to assign newly computed damping value to `self._damping`. + """ + prev_batch_loss = self._loss_fn(prev_batch) + with ops.control_dependencies([prev_batch_loss]): + rho_assign = self._rho.assign( + (prev_batch_loss - self._prev_loss) / self._q_model_change) + with ops.control_dependencies([rho_assign]): + new_damping = control_flow_ops.case( + [(self._rho < 0.25, lambda: self.damping / self._omega), + (self._rho > 0.75, lambda: self.damping * self._omega)], + lambda: self.damping) + with ops.control_dependencies([new_damping]): + new_damping_min = math_ops.maximum(new_damping, self._min_damping) + return control_flow_ops.group(self._damping.assign(new_damping_min)) + + return control_flow_ops.cond( + math_ops.equal( + math_ops.mod(global_step + 1, self._damping_adaptation_interval), + 0), compute_damping, control_flow_ops.no_op) + def _inner_product_list(list1, list2): return math_ops.add_n( diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index e89508fa46b6e2ce278e5373e6c9d17203ad1ef2..88e6fb20e8f97528aea2a92752d79344c27bbf24 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -144,7 +144,9 @@ def layer_params_to_mat2d(vector): [-1, w_part.shape.as_list()[-1]]) return array_ops.concat( (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0) - else: + elif isinstance(vector, ops.IndexedSlices): + return vector + else: # Tensor or Tensor-like. return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]]) @@ -163,6 +165,11 @@ def mat2d_to_layer_params(vector_template, mat2d): if isinstance(vector_template, (tuple, list)): w_part, b_part = mat2d[:-1], mat2d[-1] return array_ops.reshape(w_part, vector_template[0].shape), b_part + elif isinstance(vector_template, ops.IndexedSlices): + if not isinstance(mat2d, ops.IndexedSlices): + raise TypeError( + "If vector_template is an IndexedSlices, so should mat2d.") + return mat2d else: return array_ops.reshape(mat2d, vector_template.shape) @@ -234,19 +241,22 @@ class SubGraph(object): # Set of all ancestor Tensors, Ops to 'outputs'. self._members = set() - self._recurse_add(outputs) + self._iter_add(outputs) - def _recurse_add(self, nodes): - """Recursively adds all of nodes' ancestors.""" - for node in nodes: - if node in self._members: - continue - self._members.add(node) + def _iter_add(self, root): + """Iteratively adds all of nodes' ancestors using depth first search.""" + stack = [root] + while stack: + nodes = stack.pop() + for node in nodes: + if node in self._members: + continue + self._members.add(node) - if isinstance(node, ops.Tensor): - self._recurse_add((node.op,)) - elif isinstance(node, ops.Operation): - self._recurse_add(node.inputs) + if isinstance(node, ops.Tensor): + stack.append((node.op,)) + elif isinstance(node, ops.Operation): + stack.append(node.inputs) def is_member(self, node): """Check if 'node' is in this subgraph.""" @@ -420,5 +430,57 @@ def batch_execute(global_step, thunks, batch_size, name=None): return result +def matmul_sparse_dense(A, B, name=None): # pylint: disable=invalid-name + """Computes matmul(A, B) where A is sparse, B is dense. + + Args: + A: tf.IndexedSlices with dense shape [m, n]. + B: tf.Tensor with shape [n, k]. + name: str. Name of op. + + Returns: + tf.IndexedSlices resulting from matmul(A, B). + + Raises: + ValueError: If A doesn't represent a matrix. + ValueError: If B is not rank-2. + """ + with ops.name_scope(name, "matmul_sparse_dense", [A, B]): + if A.indices.shape.ndims != 1 or A.values.shape.ndims != 2: + raise ValueError("A must represent a matrix. Found: %s." % A) + if B.shape.ndims != 2: + raise ValueError("B must be a matrix.") + new_values = math_ops.matmul(A.values, B) + return ops.IndexedSlices( + new_values, + A.indices, + dense_shape=array_ops.stack([A.dense_shape[0], new_values.shape[1]])) + + +def matmul_diag_sparse(A_diag, B, name=None): # pylint: disable=invalid-name + """Computes matmul(A, B) where A is a diagonal matrix, B is sparse. + + Args: + A_diag: diagonal entries of matrix A of shape [m, m]. + B: tf.IndexedSlices. Represents matrix of shape [m, n]. + name: str. Name of op. + + Returns: + tf.IndexedSlices resulting from matmul(A, B). + + Raises: + ValueError: If A_diag is not rank-1. + ValueError: If B doesn't represent a matrix. + """ + with ops.name_scope(name, "matmul_diag_sparse", [A_diag, B]): + A_diag = ops.convert_to_tensor(A_diag) + if A_diag.shape.ndims != 1: + raise ValueError("A_diag must be a rank-1 Tensor.") + if B.indices.shape.ndims != 1 or B.values.shape.ndims != 2: + raise ValueError("B must represent a matrix. Found: %s." % B) + a = array_ops.gather(A_diag, B.indices) + a = array_ops.reshape(a, list(a.shape) + [1] * (B.values.shape.ndims - 1)) + return ops.IndexedSlices(a * B.values, B.indices, dense_shape=B.dense_shape) + # TODO(b/69623235): Add a function for finding tensors that share gradients # to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index fe8e39c212c2c3381f9aa6fdb9fdf423ff958481..8e424a794691484fdea7d8481677aa641c433d4c 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -40,6 +40,8 @@ _allowed_symbols = [ "fwd_gradients", "ensure_sequence", "batch_execute", + "matmul_sparse_dense", + "matmul_diag_sparse", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core.py b/tensorflow/contrib/labeled_tensor/python/ops/core.py index abc18aa123bb4d40b54d22ec03257c5350118d13..0c6bba758b429a8c4112bc6abb2fae542b5dfc14 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core.py @@ -361,6 +361,10 @@ class LabeledTensor(object): def dtype(self): return self._tensor.dtype + @property + def shape(self): + return self._tensor.shape + @property def name(self): return self._tensor.name diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py index e70b4923749d89aba1bd0187857d762305daeb07..e378db56afb1d4f9463d2c9b0f1fa4c0feea8fb0 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py @@ -244,6 +244,9 @@ class LabeledTensorTest(test_util.Base): def test_dtype(self): self.assertEqual(self.lt.dtype, self.lt.tensor.dtype) + def test_shape(self): + self.assertEqual(self.lt.shape, self.lt.tensor.shape) + def test_get_shape(self): self.assertEqual(self.lt.get_shape(), self.lt.tensor.get_shape()) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/ops.py b/tensorflow/contrib/labeled_tensor/python/ops/ops.py index c957b41a49b292225e547ce17b0c5a247810325a..3ba1026383ef146adb32197ae41b5c251155bf46 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/ops.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/ops.py @@ -951,7 +951,7 @@ def define_reduce_op(op_name, reduce_fn): intermediate_axes.append(axis) reduce_op = reduce_fn( - labeled_tensor.tensor, reduction_dimensions, keep_dims=True) + labeled_tensor.tensor, reduction_dimensions, keepdims=True) reduce_lt = core.LabeledTensor(reduce_op, intermediate_axes) return squeeze(reduce_lt, axes_to_squeeze, name=scope) diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index b7d34d6435789e54403926a342481971e854b449..9ccb589d698ad83c9654f5523ccdcb35b031b3da 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -154,6 +154,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation +from tensorflow.python.util import nest # Imports the core `InputLayer` symbol in contrib during development. @@ -554,28 +555,70 @@ def sparse_column_with_integerized_feature(column_name, class _SparseColumnHashed(_SparseColumn): """See `sparse_column_with_hash_bucket`.""" + def __new__(cls, + column_name, + is_integerized=False, + bucket_size=None, + lookup_config=None, + combiner="sum", + dtype=dtypes.string, + hash_keys=None): + if hash_keys is not None: + if not isinstance(hash_keys, list) or not hash_keys: + raise ValueError("hash_keys must be a non-empty list.") + if (any([not isinstance(key_pair, list) for key_pair in hash_keys]) or + any([len(key_pair) != 2 for key_pair in hash_keys]) or + any([not isinstance(key, int) for key in nest.flatten(hash_keys)])): + raise ValueError( + "Each element of hash_keys must be a pair of integers.") + obj = super(_SparseColumnHashed, cls).__new__( + cls, + column_name, + is_integerized=is_integerized, + bucket_size=bucket_size, + lookup_config=lookup_config, + combiner=combiner, + dtype=dtype) + obj.hash_keys = hash_keys + return obj + def _do_transform(self, input_tensor): if self.dtype.is_integer: sparse_values = string_ops.as_string(input_tensor.values) else: sparse_values = input_tensor.values - sparse_id_values = string_ops.string_to_hash_bucket_fast( - sparse_values, self.bucket_size, name="lookup") - return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values, - input_tensor.dense_shape) + if self.hash_keys: + result = [] + for key in self.hash_keys: + sparse_id_values = string_ops.string_to_hash_bucket_strong( + sparse_values, self.bucket_size, key) + result.append( + sparse_tensor_py.SparseTensor(input_tensor.indices, + sparse_id_values, + input_tensor.dense_shape)) + return sparse_ops.sparse_concat(axis=1, sp_inputs=result, name="lookup") + else: + sparse_id_values = string_ops.string_to_hash_bucket_fast( + sparse_values, self.bucket_size, name="lookup") + return sparse_tensor_py.SparseTensor( + input_tensor.indices, sparse_id_values, input_tensor.dense_shape) def sparse_column_with_hash_bucket(column_name, hash_bucket_size, combiner="sum", - dtype=dtypes.string): + dtype=dtypes.string, + hash_keys=None): """Creates a _SparseColumn with hashed bucket configuration. Use this when your sparse features are in string or integer format, but you don't have a vocab file that maps each value to an integer ID. output_id = Hash(input_feature_string) % bucket_size + When hash_keys is set, multiple integer IDs would be created with each key + pair in the `hash_keys`. This is useful to reduce the collision of hashed ids. + Args: column_name: A string defining sparse column name. hash_bucket_size: An int that is > 1. The number of buckets. @@ -588,6 +631,9 @@ def sparse_column_with_hash_bucket(column_name, * "sqrtn": do l2 normalization on features in the column For more information: `tf.embedding_lookup_sparse`. dtype: The type of features. Only string and integer types are supported. + hash_keys: The hash keys to use. It is a list of lists of two uint64s. If + None, simple and fast hashing algorithm is used. Otherwise, multiple + strong hash ids would be produced with each two unit64s in this argument. Returns: A _SparseColumn with hashed bucket configuration @@ -600,7 +646,8 @@ def sparse_column_with_hash_bucket(column_name, column_name, bucket_size=hash_bucket_size, combiner=combiner, - dtype=dtype) + dtype=dtype, + hash_keys=hash_keys) class _SparseColumnKeys(_SparseColumn): diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index fc8f153fe3abdc83aca5abfa9a4bb5f5d5531480..1de9ab705655db9863d9c7d2630f24283c83d44d 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -329,6 +329,55 @@ class FeatureColumnTest(test.TestCase): self.assertEqual(one_hot.sparse_id_column.name, "ids_weighted_by_weights") self.assertEqual(one_hot.length, 3) + def testOneHotColumnWithSparseColumnWithHashKeys(self): + input_values = ["marlo", "unknown", "omar"] + inputs = constant_op.constant(input_values) + hash_keys = [[10, 20], [20, 30]] + hash_column = fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=10, hash_keys=hash_keys) + columns_to_tensors = {} + columns_to_tensors["ids"] = inputs + hash_column.insert_transformed_feature(columns_to_tensors) + self.assertEqual(len(columns_to_tensors), 2) + self.assertTrue(hash_column in columns_to_tensors) + + one_hot_column = fc.one_hot_column(hash_column) + one_hot_output = one_hot_column._to_dnn_input_layer( + columns_to_tensors[hash_column]) + + expected = np.array([[0., 1., 0., 0., 0., 0., 0., 1., 0., + 0.], [0., 1., 0., 0., 0., 0., 0., 0., 0., 1.], + [1., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]) + with self.test_session() as sess: + one_hot_value = sess.run(one_hot_output) + self.assertTrue(np.array_equal(one_hot_value, expected)) + + def testSparseColumnWithHashKeysWithUnexpectedHashKeys(self): + with self.assertRaisesRegexp(ValueError, + "hash_keys must be a non-empty list."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=[]) + + with self.assertRaisesRegexp(ValueError, + "hash_keys must be a non-empty list."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=1) + + with self.assertRaisesRegexp( + ValueError, "Each element of hash_keys must be a pair of integers."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=[1, 2]) + + with self.assertRaisesRegexp( + ValueError, "Each element of hash_keys must be a pair of integers."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=["key"]) + + with self.assertRaisesRegexp( + ValueError, "Each element of hash_keys must be a pair of integers."): + fc.sparse_column_with_hash_bucket( + column_name="ids", hash_bucket_size=100, hash_keys=[[1, 2.0]]) + def testMissingValueInOneHotColumnForWeightedSparseColumn(self): # Github issue 12583 ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"]) diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index c42eab4efcd480c4ac262448465a8b744fcc27ec..80cbe68870808328b387e2044fe236af5a5e39f8 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -51,7 +51,6 @@ from tensorflow.python.ops import standard_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as tf_variables from tensorflow.python.training import moving_averages -from tensorflow.python.layers.maxout import maxout # TODO(b/28426988): Replace legacy_* fns migrated from slim. # TODO(b/28426988): Remove legacy_* when all uses have migrated to new API. @@ -518,8 +517,8 @@ def batch_norm(inputs, then the batch normalization uses weighted mean and variance. (This can be used to correct for bias in training example selection.) - fused: if `True`, use a faster, fused implementation if possible. - If `None`, use the system recommended implementation. + fused: if `None` or `True`, use a faster, fused implementation if possible. + If `False`, use the system recommended implementation. data_format: A string. `NHWC` (default) and `NCHW` are supported. zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new pair of variables 'moving_mean/biased' and 'moving_mean/local_step'. @@ -2187,8 +2186,10 @@ def layer_norm(inputs, @add_arg_scope -def images_to_sequence(inputs, data_format=DATA_FORMAT_NHWC, - outputs_collections=None, scope=None): +def images_to_sequence(inputs, + data_format=DATA_FORMAT_NHWC, + outputs_collections=None, + scope=None): """Convert a batch of images into a batch of sequences. Args: inputs: a (num_images, height, width, depth) tensor @@ -2694,8 +2695,11 @@ def separable_convolution2d( @add_arg_scope -def sequence_to_images(inputs, height, output_data_format='channels_last', - outputs_collections=None, scope=None): +def sequence_to_images(inputs, + height, + output_data_format='channels_last', + outputs_collections=None, + scope=None): """Convert a batch of sequences into a batch of images. Args: inputs: (num_steps, num_batches, depth) sequence tensor @@ -2936,6 +2940,53 @@ def unit_norm(inputs, dim, epsilon=1e-7, scope=None): return math_ops.div(inputs, array_ops.tile(lengths, multiples)) +@add_arg_scope +def maxout(inputs, num_units, axis=-1, scope=None): + """Adds a maxout op from https://arxiv.org/abs/1302.4389 + + "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron + Courville, + Yoshua Bengio + + Usually the operation is performed in the filter/channel dimension. This can + also be + used after fully-connected layers to reduce number of features. + + Arguments: + inputs: Tensor input + num_units: Specifies how many features will remain after maxout + in the `axis` dimension (usually channel). + This must be multiple of number of `axis`. + axis: The dimension where max pooling will be performed. Default is the + last dimension. + scope: Optional scope for variable_scope. + + Returns: + A `Tensor` representing the results of the pooling operation. + + Raises: + ValueError: if num_units is not multiple of number of features. + """ + with variable_scope.variable_scope(scope, 'MaxOut', [inputs]): + inputs = ops.convert_to_tensor(inputs) + shape = inputs.get_shape().as_list() + num_channels = shape[axis] + if num_channels % num_units: + raise ValueError('number of features({}) is not ' + 'a multiple of num_units({})'.format( + num_channels, num_units)) + shape[axis] = -1 + shape += [num_channels // num_units] + + # Dealing with batches with arbitrary sizes + for i in range(len(shape)): + if shape[i] is None: + shape[i] = array_ops.shape(inputs)[i] + outputs = math_ops.reduce_max( + array_ops.reshape(inputs, shape), -1, keepdims=False) + return outputs + + def poincare_normalize(x, axis=1, epsilon=1e-5, name=None): """Project into the Poincare ball with norm <= 1.0 - epsilon. diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 0f062adbab3ca9acfb89543b69c7c957bbdf5dd8..997f910a2a97567adbd7ffa3e81a31d2ae0bad7e 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -4135,5 +4135,31 @@ class LegacyFullyConnectedTest(test.TestCase): _layers.legacy_fully_connected(x, 2, activation_fn=nn_ops.softmax) +class MaxOutTest(test.TestCase): + + def test_simple(self): + inputs = random_ops.random_uniform((64, 10, 36), seed=1) + graph = _layers.maxout(inputs, num_units=3) + self.assertEqual(graph.get_shape().as_list(), [64, 10, 3]) + + def test_fully_connected(self): + inputs = random_ops.random_uniform((64, 50), seed=1) + graph = _layers.fully_connected(inputs, 50) + graph = _layers.maxout(graph, num_units=10) + self.assertEqual(graph.get_shape().as_list(), [64, 10]) + + def test_nchw(self): + inputs = random_ops.random_uniform((10, 100, 100, 3), seed=1) + graph = _layers.conv2d(inputs, 10, 3, padding='SAME') + graph = _layers.maxout(graph, num_units=1) + self.assertEqual(graph.get_shape().as_list(), [10, 100, 100, 1]) + + def test_invalid_shape(self): + inputs = random_ops.random_uniform((10, 100, 100, 3), seed=1) + graph = _layers.conv2d(inputs, 3, 10) + with self.assertRaisesRegexp(ValueError, 'number of features'): + graph = _layers.maxout(graph, num_units=2) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index cdceea6fee5bdb5aeb6537ea55d25ccf107def4c..69d927e1b3001d14dd1af2f890b07c1a57ab2cfc 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -41,7 +41,7 @@ OPTIMIZER_CLS_NAMES = { "Adagrad": train.AdagradOptimizer, "Adam": train.AdamOptimizer, "Ftrl": train.FtrlOptimizer, - "Momentum": lambda lr: train.MomentumOptimizer(lr, momentum=0.9), + "Momentum": lambda learning_rate: train.MomentumOptimizer(learning_rate, momentum=0.9), # pylint: disable=line-too-long "RMSProp": train.RMSPropOptimizer, "SGD": train.GradientDescentOptimizer, } diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 1ea25bd1a5685eb6f840e621b5739029a660aa0f..a4461a20e54c289886f1a1beb255de12fc054afe 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -61,7 +61,8 @@ class OptimizersTest(test.TestCase): optimizers = [ "SGD", gradient_descent.GradientDescentOptimizer, gradient_descent.GradientDescentOptimizer(learning_rate=0.1), - lambda lr: gradient_descent.GradientDescentOptimizer(learning_rate=lr) + lambda lr: gradient_descent.GradientDescentOptimizer(learning_rate=lr), + "Momentum" ] for optimizer in optimizers: with ops.Graph().as_default() as g: diff --git a/tensorflow/contrib/learn/README.md b/tensorflow/contrib/learn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d516bffc5e0327a3400068b35de5503e5a925a54 --- /dev/null +++ b/tensorflow/contrib/learn/README.md @@ -0,0 +1,143 @@ +EVERYTHING IN THIS DIRECTORY IS DEPRECATED. + +Using functions or classes will result in warnings. + +Instructions for converting to current alternatives are included in the +warnings. A high-level overview is below. + +## Canned Estimators + +Many canned estimators (subclasses of `Estimator`) have equivalents in core: +`DNNClassifier`, `DNNRegressor`, `DNNEstimator`, `LinearClassifier`, +`LinearRegressor`, `DNNLinearCombinedClassifier` and +`DNNLinearCombinedRegressor`. They are exposed under `tf.estimator`. +`DNNEstimator`, `LinearEstimator` and `DNNLinearCombinedEstimator` +are exposed under `tf.contrib.estimator`. + +To migrate to the new api, users need to take the following steps: + +* Replace `tf.contrib.learn` with `tf.estimator`. +* If you subclass any of the estimators, stop doing that. You should be able to + write a factory method that returns a canned estimator instead. If this is not + possible (if you override methods from the canned estimator), consider writing + a custom estimator instead. See `tf.estimator.Estimator`. +* Set `loss_reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE` to preserve loss + reduction as the average over batch. +* Some optimizer-related arguments are no longer passed in the estimator + constructor. Instead, we provide methods that perform the same job by wrapping + an optimizer. Specifically: + * `gradient_clip_norm`: Use `tf.contrib.estimator.clip_gradients_by_norm` + * `embedding_lr_multipliers`: Not supported. + Other arguments: + * `input_layer_min_slice_size`: Replaced by `input_layer_partitioner` + * `enable_centered_bias`: Not supported. Dropping this argument is unlikely to + harm your model. + * `feature_engineering_fn`: Not supported. You can call your + `feature_engineering_fn` inside your input_fn: + ```python + def new_input_fn(): + features, labels = old_input_fn() + return feature_engineering_fn(features, labels) + ``` +* Use `tf.reshape` to reshape labels in your `input_fn`. `tf.estimator` + classifiers and regressors expect labels as a 2D Tensor of shape + `[batch_size, 1]`, or `[batch_size, n_labels]`. In contrast, + `tf.contrib.learn` classifiers and regressors supported labels with shape + `[batch_size]`. +* If you pass custom metrics from the `evaluate()` method call, use + `tf.contrib.estimator.add_metrics`. +* Replace your `serving_input_fn` with a `serving_input_receiver_fn`. + Note this should be entirely distinct from your training `input_fn`, so if you + previously had one `input_fn` with different "modes", you should now factor + that apart. Where the former returned either a simple `(features, labels)` + tuple or `InputFnOps`, you should now return a `ServingInputReceiver`. + If you were generating your `serving_input_fn` using the + `build_parsing_serving_input_fn` helper, you can simply drop in the + replacement `build_parsing_serving_input_receiver_fn`. + +Some remaining estimators/classes: + +* `DynamicRnnEstimator`: Consider a custom `model_fn`. +* `KMeansClustering`: Use `tf.contrib.factorization.KMeansClustering`. +* `LogisticRegressor`: Not supported. Instead, use `binary_classification_head` + with a custom `model_fn`, or with `DNNEstimator`. +* `StateSavingRnnEstimator`: Consider a custom `model_fn`. +* SVM: Consider a custom `model_fn`. +* `LinearComposableModel` and `DNNComposableModel`: Not supported. + Consider `tf.contrib.estimator.DNNEstimator`, or write a custom model_fn. +* `MetricSpec`: Deprecated. For adding custom metrics to canned Estimators, use + `tf.contrib.estimator.add_metrics`. + +## Estimator +`tf.contrib.learn.Estimator` is migrated to `tf.estimator.Estimator`. + +To migrate, users need to take the following steps: + +* Replace `tf.contrib.learn.Estimator` with `tf.estimator.Estimator`. +* If you pass a `config` argument to `Estimator`, this must be + `tf.estimator.RunConfig`. You may need to edit your code accordingly. +* Edit your `model_fn` to return `tf.estimator.EstimatorSpec`. Refer to + `EstimatorSpec` for documentation of specific fields. +* If your `model_fn` uses the `mode` argument, use `tf.estimator.ModeKeys`. + +Some related classes: +* `Evaluable`, `Trainable`: Not supported, merged into `tf.estimator.Estimator`. +* ExportStrategy: Replaced by `tf.estimator.Exporter`. + +## Head/MultiHead +These classes are now supported under `tf.contrib.estimator`, e.g. +`tf.contrib.estimator.multi_class_head` and `tf.contrib.estimator.multi_head`. + +Some differences: + +* `multi_class_head`: If you use `tf.contrib.learn.multi_class_head` with + `n_classes=2`, switch to `tf.contrib.estimator.binary_classification_head`. +* `loss_only_head`: Not supported. +* `poisson_regression_head`: Not supported (yet). +* `binary_svm_head`: Not supported (yet). +* `no_op_train_fn`: Replace it with `tf.no_op`. + +Some arguments are renamed, please refer to documentation. In addition: + +* `loss_fn`: Supported for `multi_label_head`. If you need it for other heads, + please open an issue. +* `metric_class_ids`: Not supported (yet). +* `enable_centered_bias`: Not supported. Dropping this argument is unlikely to + harm your model. +* `label_name`: Not needed in `tf.estimator`. If you don’t use `multi_head`, + drop this argument. If you use `multi_head`, refer to + `tf.contrib.estimator.multi_head` documentation. + +## Experiment Class - Distributed Training Tooling + +Switch to `tf.estimator.train_and_evaluate`. Some differences: + +* Most of the constructor arguments, like `train_input_fn`, `eval_input_fn`, + should be wrapped into `tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`. +* Remove the `experiment_fn`. Instead, create the `Estimator`, + `train_spec` and `eval_spec`, then call `tf.estimator.train_and_evaluate` + directly. +* Inside `tf.estimator.EvalSpec`, the `exporter` field is the replacement + for `export_strategy`. To be precise, `tf.estimator.LatestExporter` is the + replacement for `tf.contrib.learn.make_export_strategy`. If you want to export + only at the end of training use `tf.estimator.FinalExporter`. +* If the `TF_CONFIG` environment variable is constructed manually, please read + the `train_and_evaluate` documentation for the new requirementds (in + particular, the chief node and evaluator node). + +## Others Classes and Functions + +* `tf.contrib.learn.datasets` is deprecated. We are adding ready to use datasets + to tensorflow/models. Many smaller datasets are available from other sources, + such as scikits.learn. Some Python processing may have to be written, but this + is straightforward to implement using the standard modules. +* `tf.contrib.learn.preprocessing`: Deprecated. The python-only preprocessing + functions are not a good fit for TensorFlow. Please use `tf.data`, and + consider tensorflow/transform for more complex use cases. +* `tf.contrib.learn.models`: Not supported, use canned estimators instead. +* `tf.contrib.learn.monitors`: Implement `SessionRunHook` instead. Hook + implementations are in `tf.train`. +* `tf.contrib.learn.learn_io`: Use the methods in `tf.estimator.inputs`, such as + `tf.estimator.inputs.numpy_input_fn`. Some utility functions have no + equivalent, we encourage the use of `tf.data`. + diff --git a/tensorflow/contrib/learn/__init__.py b/tensorflow/contrib/learn/__init__.py index 3698af027e38f1063ad829c26eb179734968f813..79bd73faaf1301a2fc4999b64f88d30542577980 100644 --- a/tensorflow/contrib/learn/__init__.py +++ b/tensorflow/contrib/learn/__init__.py @@ -13,8 +13,11 @@ # limitations under the License. # ============================================================================== -# TODO(ptucker,ipolosukhin): Improve descriptions. -"""High level API for learning. +"""High level API for learning (DEPRECATED). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. See the @{$python/contrib.learn} guide. diff --git a/tensorflow/contrib/learn/python/__init__.py b/tensorflow/contrib/learn/python/__init__.py index bbebd5ab9792cb937219cf937f08c4d4e6e44a92..df23aeb2c433c2b4392f706730f715246ce01cea 100644 --- a/tensorflow/contrib/learn/python/__init__.py +++ b/tensorflow/contrib/learn/python/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level API for learning with TensorFlow.""" +"""High level API for learning with TensorFlow (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/__init__.py b/tensorflow/contrib/learn/python/learn/__init__.py index cdc67c77d5fd1df61016835dc75ba44feb458cf9..76e0e8ac8f19026086959f3b197cfd1a81e65a3e 100644 --- a/tensorflow/contrib/learn/python/learn/__init__.py +++ b/tensorflow/contrib/learn/python/learn/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level API for learning with TensorFlow.""" +"""High level API for learning with TensorFlow (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py b/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py index 2284ec46e971731af74f17678fc0d1d3888419e2..fed1c44d1970bf07c808ace817aa9972d7776d88 100644 --- a/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py +++ b/tensorflow/contrib/learn/python/learn/basic_session_run_hooks.py @@ -12,20 +12,47 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Some common SessionRunHook classes.""" +"""Some common SessionRunHook classes (deprected). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.training import basic_session_run_hooks +from tensorflow.python.util.deprecation import deprecated_alias # pylint: disable=invalid-name -LoggingTensorHook = basic_session_run_hooks.LoggingTensorHook -StopAtStepHook = basic_session_run_hooks.StopAtStepHook -CheckpointSaverHook = basic_session_run_hooks.CheckpointSaverHook -StepCounterHook = basic_session_run_hooks.StepCounterHook -NanLossDuringTrainingError = basic_session_run_hooks.NanLossDuringTrainingError -NanTensorHook = basic_session_run_hooks.NanTensorHook -SummarySaverHook = basic_session_run_hooks.SummarySaverHook +LoggingTensorHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.LoggingTensorHook', + 'tf.train.LoggingTensorHook', + basic_session_run_hooks.LoggingTensorHook) +StopAtStepHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.StopAtStepHook', + 'tf.train.StopAtStepHook', + basic_session_run_hooks.StopAtStepHook) +CheckpointSaverHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.CheckpointSaverHook', + 'tf.train.CheckpointSaverHook', + basic_session_run_hooks.CheckpointSaverHook) +StepCounterHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.StepCounterHook', + 'tf.train.StepCounterHook', + basic_session_run_hooks.StepCounterHook) +NanLossDuringTrainingError = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.NanLossDuringTrainingError', + 'tf.train.NanLossDuringTrainingError', + basic_session_run_hooks.NanLossDuringTrainingError) +NanTensorHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.NanTensorHook', + 'tf.train.NanTensorHook', + basic_session_run_hooks.NanTensorHook) +SummarySaverHook = deprecated_alias( + 'tf.contrib.learn.basic_session_run_hooks.SummarySaverHook', + 'tf.train.SummarySaverHook', + basic_session_run_hooks.SummarySaverHook) # pylint: enable=invalid-name diff --git a/tensorflow/contrib/learn/python/learn/datasets/__init__.py b/tensorflow/contrib/learn/python/learn/datasets/__init__.py index 7240b0de149051afa045a8113f9e9b212840c311..3c34712ac859d32f549468345950a93d2ed2aa56 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/__init__.py +++ b/tensorflow/contrib/learn/python/learn/datasets/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Dataset utilities and synthetic/reference datasets.""" +"""Dataset utilities and synthetic/reference datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -27,6 +32,7 @@ from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.datasets import mnist from tensorflow.contrib.learn.python.learn.datasets import synthetic from tensorflow.contrib.learn.python.learn.datasets import text_datasets +from tensorflow.python.util.deprecation import deprecated # Export load_iris and load_boston. load_iris = base.load_iris @@ -51,6 +57,7 @@ SYNTHETIC = { } +@deprecated(None, 'Please use tf.data.') def load_dataset(name, size='small', test_with_fake_data=False): """Loads dataset by name. @@ -73,8 +80,9 @@ def load_dataset(name, size='small', test_with_fake_data=False): return DATASETS[name]() +@deprecated(None, 'Please use tf.data.') def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs): - """Creates binary synthetic datasets + """Creates binary synthetic datasets. Args: name: str, name of the dataset to generate diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py index ca720ae5ed26e74da12bd6c5a37231b41442f76f..3b5c9b97c08a388e1f35249967b6cab26861f100 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/base.py +++ b/tensorflow/contrib/learn/python/learn/datasets/base.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Base utilities for loading datasets.""" + +"""Base utilities for loading datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -29,11 +35,14 @@ import numpy as np from six.moves import urllib from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated + Dataset = collections.namedtuple('Dataset', ['data', 'target']) Datasets = collections.namedtuple('Datasets', ['train', 'validation', 'test']) +@deprecated(None, 'Use tf.data instead.') def load_csv_with_header(filename, target_dtype, features_dtype, @@ -53,6 +62,7 @@ def load_csv_with_header(filename, return Dataset(data=data, target=target) +@deprecated(None, 'Use tf.data instead.') def load_csv_without_header(filename, target_dtype, features_dtype, @@ -70,6 +80,7 @@ def load_csv_without_header(filename, return Dataset(data=data, target=target) +@deprecated(None, 'Use tf.data instead.') def shrink_csv(filename, ratio): """Create a smaller dataset of only 1/ratio of original data.""" filename_small = filename.replace('.', '_small.') @@ -84,6 +95,7 @@ def shrink_csv(filename, ratio): i += 1 +@deprecated(None, 'Use scikits.learn.datasets.') def load_iris(data_path=None): """Load Iris dataset. @@ -100,6 +112,7 @@ def load_iris(data_path=None): data_path, target_dtype=np.int, features_dtype=np.float) +@deprecated(None, 'Use scikits.learn.datasets.') def load_boston(data_path=None): """Load Boston housing dataset. @@ -116,7 +129,12 @@ def load_boston(data_path=None): data_path, target_dtype=np.float, features_dtype=np.float) -def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): +@deprecated(None, 'Use the retry module or similar alternatives.') +def retry(initial_delay, + max_delay, + factor=2.0, + jitter=0.25, + is_retriable=None): """Simple decorator for wrapping retriable functions. Args: @@ -152,7 +170,7 @@ def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): for delay in delays(): try: return fn(*args, **kwargs) - except Exception as e: # pylint: disable=broad-except) + except Exception as e: # pylint: disable=broad-except if is_retriable is None: continue @@ -176,11 +194,13 @@ def _is_retriable(e): return isinstance(e, IOError) and e.errno in _RETRIABLE_ERRNOS +@deprecated(None, 'Please use urllib or similar directly.') @retry(initial_delay=1.0, max_delay=16.0, is_retriable=_is_retriable) def urlretrieve_with_retry(url, filename=None): return urllib.request.urlretrieve(url, filename) +@deprecated(None, 'Please write your own downloading logic.') def maybe_download(filename, work_directory, source_url): """Download the data from source url, unless it's already here. diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py index 37f9175015a239f763c7721cf36ab8063c0a3e32..abbb44c2f5b701829ce16f64eadd8ebc04c84e2c 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py +++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functions for downloading and reading MNIST data.""" +"""Functions for downloading and reading MNIST data (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -27,6 +32,7 @@ from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated # CVDF mirror of http://yann.lecun.com/exdb/mnist/ DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/' @@ -37,6 +43,7 @@ def _read32(bytestream): return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] +@deprecated(None, 'Please use tf.data to implement this functionality.') def extract_images(f): """Extract the images into a 4D uint8 numpy array [index, y, x, depth]. @@ -65,6 +72,7 @@ def extract_images(f): return data +@deprecated(None, 'Please use tf.one_hot on tensors.') def dense_to_one_hot(labels_dense, num_classes): """Convert class labels from scalars to one-hot vectors.""" num_labels = labels_dense.shape[0] @@ -74,6 +82,7 @@ def dense_to_one_hot(labels_dense, num_classes): return labels_one_hot +@deprecated(None, 'Please use tf.data to implement this functionality.') def extract_labels(f, one_hot=False, num_classes=10): """Extract the labels into a 1D uint8 numpy array [index]. @@ -103,7 +112,15 @@ def extract_labels(f, one_hot=False, num_classes=10): class DataSet(object): + """Container class for a dataset (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def __init__(self, images, labels, @@ -210,6 +227,8 @@ class DataSet(object): return self._images[start:end], self._labels[start:end] +@deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def read_data_sets(train_dir, fake_data=False, one_hot=False, @@ -275,5 +294,7 @@ def read_data_sets(train_dir, return base.Datasets(train=train, validation=validation, test=test) +@deprecated(None, 'Please use alternatives such as official/mnist/dataset.py' + ' from tensorflow/models.') def load_mnist(train_dir='MNIST-data'): return read_data_sets(train_dir) diff --git a/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py b/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py index 6e0ba38941ce4650ede9f7210e284bde2ed8e6a9..a4848fa64a72f031ef35c0c3256e97a7326acd60 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py +++ b/tensorflow/contrib/learn/python/learn/datasets/produce_small_datasets.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Produce DBpedia datasets of a smaller size.""" +"""Produce DBpedia datasets of a smaller size (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py index 9a843168c27d9cae3f55efe4fe4c688d86c745f3..6a0e3350b3d1052249160a2a997a76de7a5040c3 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py +++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Synthetic dataset generators.""" +"""Synthetic dataset generators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,8 +26,10 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.learn.python.learn.datasets.base import Dataset +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Consider using synthetic datasets from scikits.learn.') def circles(n_samples=100, noise=None, seed=None, @@ -93,6 +100,7 @@ def circles(n_samples=100, return Dataset(data=X[indices], target=y[indices]) +@deprecated(None, 'Consider using synthetic datasets from scikits.learn.') def spirals(n_samples=100, noise=None, seed=None, diff --git a/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py b/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py index 2596a2ecaf1572506504831e8b08fab9b5dbc119..ce9466301728082f8e9d99c90989ba8fe623bcf0 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py +++ b/tensorflow/contrib/learn/python/learn/datasets/text_datasets.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Text datasets.""" +"""Text datasets (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -26,10 +31,12 @@ import numpy as np from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated DBPEDIA_URL = 'https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz' +@deprecated(None, 'See contrib/learn/README.md') def maybe_download_dbpedia(data_dir): """Download if DBpedia data is not present.""" train_path = os.path.join(data_dir, 'dbpedia_csv/train.csv') @@ -41,6 +48,7 @@ def maybe_download_dbpedia(data_dir): tfile.extractall(data_dir) +@deprecated(None, 'See contrib/learn/README.md') def load_dbpedia(size='small', test_with_fake_data=False): """Get DBpedia datasets from CSV files.""" if not test_with_fake_data: diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py index 4981750c94c7ac31e23b7a3f71ca30e3c9573a20..3e64595f312bcc2a2e8dcba589fb993a249b684b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py +++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""An estimator is a rule for calculating an estimate of a given quantity. +"""An estimator is a rule for calculating an estimate of a given quantity (deprecated). + +These classes are deprecated and replaced with `tf.estimator`. + +See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. # Estimators diff --git a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py index 15277415a1ce83dc1d4a334e60fe1933ba244df0..1f0e4663d060a3850e2002b27f809fde1db47e48 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/_sklearn.py @@ -13,7 +13,7 @@ # limitations under the License. # ============================================================================== -"""sklearn cross-support.""" +"""sklearn cross-support (deprecated).""" from __future__ import absolute_import from __future__ import division @@ -132,6 +132,8 @@ class _TransformerMixin(): class NotFittedError(ValueError, AttributeError): """Exception class to raise if estimator is used before fitting. + USE OF THIS EXCEPTION IS DEPRECATED. + This class inherits from both ValueError and AttributeError to help with exception handling and backward compatibility. diff --git a/tensorflow/contrib/learn/python/learn/estimators/composable_model.py b/tensorflow/contrib/learn/python/learn/estimators/composable_model.py index a02c726c74946d93b8e1726473db746220b00795..1fa58271e2b886cd143683a759145fd750791473 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/composable_model.py +++ b/tensorflow/contrib/learn/python/learn/estimators/composable_model.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TensorFlow composable models used as building blocks for estimators.""" +"""TensorFlow composable models used as building blocks for estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -34,6 +39,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import variable_scope from tensorflow.python.summary import summary +from tensorflow.python.util.deprecation import deprecated class _ComposableModel(object): @@ -46,6 +52,7 @@ class _ComposableModel(object): _ComposableModel and its subclasses are not part of the public tf.learn API. """ + @deprecated(None, "Please use model_fns in tf.estimator.") def __init__(self, num_label_columns, optimizer, @@ -141,6 +148,10 @@ class _ComposableModel(object): class LinearComposableModel(_ComposableModel): """A _ComposableModel that implements linear regression. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Instances of this class can be used to build estimators through the use of composition. """ @@ -252,6 +263,10 @@ class LinearComposableModel(_ComposableModel): class DNNComposableModel(_ComposableModel): """A _ComposableModel that implements a DNN. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Instances of this class can be used to build estimators through the use of composition. """ diff --git a/tensorflow/contrib/learn/python/learn/estimators/constants.py b/tensorflow/contrib/learn/python/learn/estimators/constants.py index fc69e810244a182b864be856e6720f8584f7aa65..d2548946bc77dea7c452d61c7e2b6e12c3d6239a 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/constants.py +++ b/tensorflow/contrib/learn/python/learn/estimators/constants.py @@ -13,9 +13,11 @@ # limitations under the License. # ============================================================================== -"""Constants regarding Estimators. +"""Constants regarding Estimators (deprecated). -This file is obsoleted in the move of Estimator to core. +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. """ from __future__ import absolute_import from __future__ import division @@ -25,6 +27,8 @@ from __future__ import print_function class ProblemType(object): """Enum-like values for the type of problem that the model solves. + THIS CLASS IS DEPRECATED. + These values are used when exporting the model to produce the appropriate signature function for serving. diff --git a/tensorflow/contrib/learn/python/learn/estimators/debug.py b/tensorflow/contrib/learn/python/learn/estimators/debug.py index 9d5f6c2bf969d7c85d251bf1b06a0307a41b2297..24b067b7e38b12df3d1d0c49f626344217218571 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/debug.py +++ b/tensorflow/contrib/learn/python/learn/estimators/debug.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Debug estimators. +"""Debug estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Debug estimators are bias-only estimators that can be used for debugging and as simple baselines. @@ -118,6 +122,10 @@ def debug_model_fn(features, labels, mode, params, config=None): class DebugClassifier(estimator.Estimator): """A classifier for TensorFlow Debug models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -237,6 +245,10 @@ class DebugClassifier(estimator.Estimator): class DebugRegressor(estimator.Estimator): """A regressor for TensorFlow Debug models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index c17b41c0f767e19d9c3635a8f60347a49b297cfb..eabebb7e881558471c343c0573cc9a8f4a425312 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Deep Neural Network estimators.""" +"""Deep Neural Network estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -212,6 +217,10 @@ def _dnn_model_fn(features, labels, mode, params, config=None): class DNNClassifier(estimator.Estimator): """A classifier for TensorFlow DNN models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -521,6 +530,10 @@ class DNNClassifier(estimator.Estimator): class DNNRegressor(estimator.Estimator): """A regressor for TensorFlow DNN models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python @@ -796,6 +809,10 @@ class DNNRegressor(estimator.Estimator): class DNNEstimator(estimator.Estimator): """A Estimator for TensorFlow DNN models with user specified _Head. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Example: ```python diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 726612235050def6e7addb503cc6646a25de0e42..3d85533d92d17095bae9a69f229171e1bf61ba10 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow estimators for Linear and DNN joined training models.""" +"""TensorFlow estimators for Linear and DNN joined training models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -372,6 +377,10 @@ def _dnn_linear_combined_model_fn(features, labels, mode, params, config=None): class DNNLinearCombinedEstimator(estimator.Estimator): """An estimator for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. @@ -490,6 +499,10 @@ class DNNLinearCombinedEstimator(estimator.Estimator): class DNNLinearCombinedClassifier(estimator.Estimator): """A classifier for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. @@ -832,6 +845,10 @@ class DNNLinearCombinedClassifier(estimator.Estimator): class DNNLinearCombinedRegressor(estimator.Estimator): """A regressor for TensorFlow Linear and DNN joined training models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note: New users must set `fix_global_step_increment_bug=True` when creating an estimator. diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 69440e823ef1ed2d739f28bc14587891f2de80bb..a703dc66e922d48ceb64edc2a979061b8e45db49 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Estimator for Dynamic RNNs.""" +"""Estimator for Dynamic RNNs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -540,6 +545,12 @@ def _get_dynamic_rnn_model_fn( class DynamicRnnEstimator(estimator.Estimator): + """Dynamically unrolled RNN (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, problem_type, diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 4b63e08ab3372849309ee5d28d754de82e9632f4..5262e04e16ee85d1672dd495f05084ff07c8dd18 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Base Estimator class.""" +"""Base Estimator class (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -138,6 +143,7 @@ def _get_input_fn(x, y, input_fn, feed_fn, batch_size, shuffle=False, epochs=1): return df.input_builder, df.get_feed_dict_fn() +@deprecated(None, 'Please specify feature columns explicitly.') def infer_real_valued_columns_from_input_fn(input_fn): """Creates `FeatureColumn` objects for inputs defined by `input_fn`. @@ -158,6 +164,7 @@ def infer_real_valued_columns_from_input_fn(input_fn): return layers.infer_real_valued_columns(features) +@deprecated(None, 'Please specify feature columns explicitly.') def infer_real_valued_columns_from_input(x): """Creates `FeatureColumn` objects for inputs defined by input `x`. @@ -389,6 +396,10 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable): """Abstract BaseEstimator class to train and evaluate TensorFlow models. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Users should not instantiate or subclass this class. Instead, use an `Estimator`. """ @@ -399,6 +410,8 @@ class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, # TODO(wicke): Remove this once launcher takes over config functionality _Config = run_config.RunConfig # pylint: disable=invalid-name + @deprecated(None, 'Please replace uses of any Estimator from tf.contrib.learn' + ' with an Estimator from tf.estimator.*') def __init__(self, model_dir=None, config=None): """Initializes a BaseEstimator instance. @@ -1074,6 +1087,10 @@ def _identity_feature_engineering_fn(features, labels): class Estimator(BaseEstimator): """Estimator class is the basic TensorFlow model trainer/evaluator. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. """ def __init__(self, @@ -1458,8 +1475,14 @@ class Estimator(BaseEstimator): # For time of deprecation x,y from Estimator allow direct access. # pylint: disable=protected-access class SKCompat(sklearn.BaseEstimator): - """Scikit learn wrapper for TensorFlow Learn Estimator.""" + """Scikit learn wrapper for TensorFlow Learn Estimator. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please switch to the Estimator interface.') def __init__(self, estimator): self._estimator = estimator diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py index fd47710e3015de9ae6a453f98978b0ef8f88968c..e4c31396baf8271c49395a2b87b454dbc77177e2 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test_utils.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utils for Estimator.""" +"""Utils for Estimator (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 9b124b2c19f16bbc9b2afeadb82a32006e1a0ae9..2b4b6eff39f4fc8a20a149edfc07d2f4f27a9bae 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -12,8 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Abstractions for the head(s) of a model. +"""Abstractions for the head(s) of a model (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function @@ -47,11 +52,16 @@ from tensorflow.python.summary import summary from tensorflow.python.training import training from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated class Head(object): """Interface for the head/top of a model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Given logits (or output of a hidden layer), a Head knows how to compute predictions, loss, default metric and export signature. It is meant to, @@ -177,6 +187,7 @@ class Head(object): raise NotImplementedError("Calling an abstract method.") +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def regression_head(label_name=None, weight_column_name=None, label_dimension=1, @@ -216,6 +227,7 @@ def regression_head(label_name=None, link_fn=(link_fn if link_fn is not None else array_ops.identity)) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def poisson_regression_head(label_name=None, weight_column_name=None, label_dimension=1, @@ -254,6 +266,7 @@ def poisson_regression_head(label_name=None, # TODO(zakaria): Consider adding a _RegressionHead for logistic_regression +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_class_head(n_classes, label_name=None, weight_column_name=None, @@ -335,6 +348,7 @@ def multi_class_head(n_classes, label_keys=label_keys) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def binary_svm_head( label_name=None, weight_column_name=None, @@ -370,6 +384,7 @@ def binary_svm_head( thresholds=thresholds) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_label_head(n_classes, label_name=None, weight_column_name=None, @@ -430,6 +445,7 @@ def multi_label_head(n_classes, loss_fn=_wrap_custom_loss_fn(loss_fn) if loss_fn else None) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def loss_only_head(loss_fn, head_name=None): """Creates a Head that contains only loss terms. @@ -447,6 +463,7 @@ def loss_only_head(loss_fn, head_name=None): return _LossOnlyHead(loss_fn, head_name=head_name) +@deprecated(None, "Please switch to tf.contrib.estimator.*_head.") def multi_head(heads, loss_weights=None): """Creates a MultiHead stemming from same logits/hidden layer. @@ -479,6 +496,7 @@ def multi_head(heads, loss_weights=None): return _MultiHead(heads, loss_merger=_weighted_loss_merger) +@deprecated(None, "Use 'lambda _: tf.no_op()'.") def no_op_train_fn(loss): del loss return control_flow_ops.no_op() diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 7c2d9bb0767cb979dae9c84b5342d129225677ed..6d5da81b4c2087fb9c5307902e452a6220a17cd0 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -362,7 +362,7 @@ class MultiLabelHeadTest(test.TestCase): "auc_precision_recall": 0.166667, "auc_precision_recall/class0": 0, "auc_precision_recall/class1": 0., - "auc_precision_recall/class2": 1., + "auc_precision_recall/class2": 0.49999, "labels/actual_label_mean/class0": self._labels[0][0], "labels/actual_label_mean/class1": self._labels[0][1], "labels/actual_label_mean/class2": self._labels[0][2], @@ -748,7 +748,7 @@ class BinaryClassificationHeadTest(test.TestCase): "accuracy/baseline_label_mean": label_mean, "accuracy/threshold_0.500000_mean": 1. / 2, "auc": 1. / 2, - "auc_precision_recall": 0.749999, + "auc_precision_recall": 0.25, "labels/actual_label_mean": label_mean, "labels/prediction_mean": .731059, # softmax "loss": expected_loss, diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 8f9d6fc318a357853bdb8e3264f6691b410006b1..66ebcfd1d81904b9afe5be6bd1a648fe325e1e0b 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of k-means clustering on top of `Estimator` API. +"""Implementation of k-means clustering on top of `Estimator` API (deprecated). This module is deprecated. Please use @{tf.contrib.factorization.KMeansClustering} instead of @@ -153,7 +153,12 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config): # TODO(agarwal,ands): support sharded input. class KMeansClustering(estimator.Estimator): - """An Estimator for K-Means clustering.""" + """An Estimator for K-Means clustering. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE RANDOM_INIT = clustering_ops.RANDOM_INIT diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 37aa8b339622415d082933cdf66d2472a4119b48..64d7ecc68e7abb1d36a3eb098fedd8184d6e9d77 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Linear Estimators.""" +"""Linear Estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -305,6 +310,10 @@ class _SdcaUpdateWeightsHook(session_run_hook.SessionRunHook): class LinearClassifier(estimator.Estimator): """Linear classifier model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a linear model to classify instances into one of multiple possible classes. When number of possible classes is 2, this is binary classification. @@ -625,6 +634,10 @@ class LinearClassifier(estimator.Estimator): class LinearRegressor(estimator.Estimator): """Linear regressor model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a linear regression model to predict label value given observation of feature values. @@ -860,6 +873,10 @@ class LinearRegressor(estimator.Estimator): class LinearEstimator(estimator.Estimator): """Linear model with user specified head. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Train a generalized linear model to predict label value given observation of feature values. diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py index fb339160d58e09d4ffd50090f2dbbcec08bebe47..3cbcc6e98de1c915c302617e4591c9baa33adeaf 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Logistic regression (aka binary classifier) class. +"""Logistic regression (aka binary classifier) class (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This defines some useful basic metrics for using logistic regression to classify a binary event (0 vs 1). @@ -75,6 +79,10 @@ def LogisticRegressor( # pylint: disable=invalid-name feature_engineering_fn=None): """Builds a logistic regression Estimator for binary classification. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This method provides a basic Estimator with some additional metrics for custom binary classification models, including AUC, precision/recall and accuracy. diff --git a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py index 99388f116b345bd038f2985606c6203011597ea2..f264248e44d9aa48f26ee32e36746bd4c3145a8d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/metric_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/metric_key.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Enum for metric keys.""" +"""Enum for metric keys (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function class MetricKey(object): - """Metric key strings.""" + """Metric key strings (deprecated).""" + LOSS = "loss" AUC = "auc" AUC_PR = "auc_precision_recall" diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 44e6c7c52dac524a22e9099e33e2aef82f8fe7ba..dcb161180c99ce71195c820217e8bdaf79d70901 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Classes and methods related to model_fn.""" +"""Classes and methods related to model_fn (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -37,10 +42,13 @@ from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import session_run_hook +from tensorflow.python.util.deprecation import deprecated class ModeKeys(object): - """Standard names for model modes. + """Standard names for model modes (deprecated). + + THIS CLASS IS DEPRECATED. The following standard keys are defined: @@ -65,8 +73,16 @@ class ModelFnOps( 'output_alternatives', 'training_chief_hooks', 'training_hooks', 'scaffold', 'mode' ])): - """Ops returned from a model_fn.""" + """Ops returned from a model_fn. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'When switching to tf.estimator.Estimator, use ' + 'tf.estimator.EstimatorSpec. You can use the `estimator_spec`' + ' method to create an equivalent one.') def __new__(cls, mode, predictions=None, diff --git a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py index f8d87b8914307a86eb2f46123a28ff11eb925eda..6fd2fc9d592cef4e44a640e2f27cb28b367d44d5 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py +++ b/tensorflow/contrib/learn/python/learn/estimators/prediction_key.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Enum for model prediction keys. +"""Enum for model prediction keys (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This file is obsoleted in the move of Estimator to core. """ @@ -22,6 +26,8 @@ from __future__ import print_function class PredictionKey(object): + """THIS CLASS IS DEPRECATED.""" + CLASSES = "classes" PROBABILITIES = "probabilities" LOGITS = "logits" diff --git a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py index 2752bc2d90ee0f51b2c40ccc4d24a4eb21cff38f..215022e5d9e5d3cd5d6a96583b325b19a1719568 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py +++ b/tensorflow/contrib/learn/python/learn/estimators/rnn_common.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Common operations for RNN Estimators.""" +"""Common operations for RNN Estimators (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/estimators/run_config.py b/tensorflow/contrib/learn/python/learn/estimators/run_config.py index fd90fd1cc6277e7d80287aefdbab6134dac7c0d5..1d161093de01ef838d0c75ec9a39574c7529bd57 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/run_config.py +++ b/tensorflow/contrib/learn/python/learn/estimators/run_config.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Run Config.""" +"""Run Config (deprecated, use tf.estimator.RunConfig instead). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -29,11 +34,12 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as core_run_config from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import server_lib +from tensorflow.python.util.deprecation import deprecated # A list of the property names in RunConfig user allows to change. They will # not affect the execution framework, so when execution framework checks the -# `uid` of the RunConfig, it should be ingored. +# `uid` of the RunConfig, it should be ignored. _DEFAULT_UID_WHITE_LIST = [ 'tf_random_seed', 'save_summary_steps', @@ -47,6 +53,7 @@ _DEFAULT_UID_WHITE_LIST = [ class Environment(object): + """DEPRECATED CLASS.""" # For running general distributed training. CLOUD = 'cloud' # For running Google-internal distributed training. @@ -56,6 +63,7 @@ class Environment(object): class TaskType(object): + """DEPRECATED CLASS.""" MASTER = 'master' PS = 'ps' WORKER = 'worker' @@ -64,6 +72,8 @@ class TaskType(object): class ClusterConfig(object): """This class specifies the configurations for a distributed run. + THIS CLASS IS DEPRECATED. Use tf.estimator.RunConfig instead. + If you're using an `Estimator`, you should probably use the subclass RunConfig instead. """ @@ -211,10 +221,13 @@ class ClusterConfig(object): class RunConfig(ClusterConfig, core_run_config.RunConfig): """This class specifies the configurations for an `Estimator` run. - This class is the implementation of @{tf.estimator.RunConfig} interface. + This class is a deprecated implementation of @{tf.estimator.RunConfig} + interface. """ _USE_DEFAULT = 0 + @deprecated(None, 'When switching to tf.estimator.Estimator, use' + ' tf.estimator.RunConfig instead.') def __init__(self, master=None, num_cores=0, diff --git a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py index 0cea35e219a4457417a161a3ac4ac4292fd24f53..de78c72c3ae3ef14f5f7c46b1d47f82e8266c7c6 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/state_saving_rnn_estimator.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Estimator for State Saving RNNs.""" +"""Estimator for State Saving RNNs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -528,6 +533,12 @@ def _get_rnn_model_fn(cell_type, class StateSavingRnnEstimator(estimator.Estimator): + """RNN with static unrolling and state saving (deprecated). + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, problem_type, diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py index 72920d73c0c92886e54f533ad7fe170fe27d9870..3459997baba16fc0d4045e50819ecdd0e7121657 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/svm.py +++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support Vector Machine (SVM) Estimator.""" +"""Support Vector Machine (SVM) Estimator (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -36,6 +41,10 @@ def _as_iterable(preds, output): class SVM(estimator.Estimator): """Support Vector Machine (SVM) model for binary classification. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Currently, only linear SVMs are supported. For the underlying optimization problem, the `SDCAOptimizer` is used. For performance and convergence tuning, the num_loss_partitions parameter passed to `SDCAOptimizer` (see `__init__()` diff --git a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py index a120bc6cc3975a3d4559d018c8aa74ff42a16d2d..71b5658dd174d2b47e33860844359f68e6768026 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py +++ b/tensorflow/contrib/learn/python/learn/estimators/tensor_signature.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorSignature class and utilities.""" +"""TensorSignature class and utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -33,6 +38,10 @@ class TensorSignature(collections.namedtuple( "TensorSignature", ["dtype", "shape", "is_sparse"])): """Signature of the `Tensor` object. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Useful to check compatibility of tensors. Example: diff --git a/tensorflow/contrib/learn/python/learn/estimators/test_data.py b/tensorflow/contrib/learn/python/learn/estimators/test_data.py index ed201bfc58f273e6587850032386c2686aea4148..e4b057b4f5a9e081c2d891bd9828ffc315e51e91 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/test_data.py +++ b/tensorflow/contrib/learn/python/learn/estimators/test_data.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Test data utilities.""" +"""Test data utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/evaluable.py b/tensorflow/contrib/learn/python/learn/evaluable.py index 8f6cd39864b437f163dd7c1140dc88755ce98529..10881ca885599bc81386e15f814a2687d907f63b 100644 --- a/tensorflow/contrib/learn/python/learn/evaluable.py +++ b/tensorflow/contrib/learn/python/learn/evaluable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`Evaluable` interface.""" +"""`Evaluable` interface (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,6 +28,10 @@ import abc class Evaluable(object): """Interface for objects that are evaluatable by, e.g., `Experiment`. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. """ __metaclass__ = abc.ABCMeta diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index bec976afd2719138117976381669ca3292360480..9a7c4cd685b90cf3ac8922bdb031aa935c1aa64f 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Experiment class collecting information needed for a single training run.""" +"""Experiment class collecting information for a single training run (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -25,7 +30,6 @@ import os import time from tensorflow.contrib.framework import deprecated -from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import export_strategy @@ -118,6 +122,10 @@ class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): class Experiment(object): """Experiment is a class containing all information needed to train a model. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + After an experiment is created (by passing an Estimator and inputs for training and evaluation), an Experiment instance knows how to invoke training and eval loops in a sensible fashion for distributed training. @@ -125,16 +133,8 @@ class Experiment(object): # TODO(ispir): remove delay_workers_by_global_step and make global step based # waiting as only behavior. - @deprecated_args( - "2016-10-23", - "local_eval_frequency is deprecated as local_run will be renamed to " - "train_and_evaluate. Use min_eval_frequency and call train_and_evaluate " - "instead. Note, however, that the default for min_eval_frequency is 1, " - "meaning models will be evaluated every time a new checkpoint is " - "available. In contrast, the default for local_eval_frequency is None, " - "resulting in evaluation occurring only after training has completed. " - "min_eval_frequency is ignored when calling the deprecated local_run.", - "local_eval_frequency") + @deprecated(None, "Please switch to tf.estimator.train_and_evaluate. You will" + " also have to convert to a tf.estimator.Estimator.") def __init__(self, estimator, train_input_fn, @@ -152,7 +152,8 @@ class Experiment(object): export_strategies=None, train_steps_per_iteration=None, checkpoint_and_export=False, - saving_listeners=None): + saving_listeners=None, + check_interval_secs=5): """Constructor for `Experiment`. Creates an Experiment instance. None of the functions passed to this @@ -190,8 +191,9 @@ class Experiment(object): number of steps between evaluations. Of course, evaluation does not occur if no new snapshot is available, hence, this is the minimum. If 0, the evaluation will only happen after training. - If None, defaults to 1, unless model_dir is on GCS, in which case the - default is 1000. + If None, defaults to 1. To avoid checking for new checkpoints too + frequent, the interval is further limited to be at least + check_interval_secs between checks. delay_workers_by_global_step: if `True` delays training workers based on global step instead of time. export_strategies: Iterable of `ExportStrategy`s, or a single one, or @@ -215,7 +217,10 @@ class Experiment(object): saving_listeners: list of `CheckpointSaverListener` objects. Used by tf.estimator.Estimator for callbacks that run immediately before or after checkpoint savings. - + check_interval_secs: + Minimum time between subsequent checks for a new checkpoint. This + mostly applies if both min_eval_frequency and the time spent per + training step is low. Raises: ValueError: if `estimator` does not implement Estimator interface, or if export_strategies has the wrong type. @@ -261,13 +266,9 @@ class Experiment(object): self._continuous_eval_throttle_secs = continuous_eval_throttle_secs self._checkpoint_and_export = checkpoint_and_export self._saving_listeners = saving_listeners - # Using 1 on a non-cached file system requires a lot of overhead to - # read the checkpoint state file. This is particular bad on GCS, so - # we use a different default. This is a temporary band-aid, to be - # fixed holistically later (b/36498507). - default_min_eval_frequency = 1000 if _is_gcs(estimator.model_dir) else 1 self._min_eval_frequency = min_eval_frequency if ( - min_eval_frequency is not None) else default_min_eval_frequency + min_eval_frequency is not None) else 1 + self._check_interval_secs = check_interval_secs self._delay_workers_by_global_step = delay_workers_by_global_step self._train_monitors = train_monitors[:] if train_monitors else [] self._eval_hooks = eval_hooks[:] if eval_hooks else [] @@ -646,12 +647,19 @@ class Experiment(object): self._train_monitors += [saver_hook] else: if self._min_eval_frequency: + # Using low min_eval_frequency (default is 1) on a non-cached file + # system requires a lot of overhead to read the checkpoint state file. + # This is particular bad on GCS and CNS. See also b/36498507 for + # context. `check_interval_secs = 5` avoids polling a remote + # fileystem too often. + self._train_monitors += [ monitors.ValidationMonitor( input_fn=self._eval_input_fn, eval_steps=self._eval_steps, metrics=self._eval_metrics, every_n_steps=self._min_eval_frequency, + check_interval_secs=self._check_interval_secs, name=eval_dir_suffix, hooks=self._eval_hooks) ] @@ -928,7 +936,3 @@ def _new_attr_context(obj, attr): yield finally: setattr(obj, attr, saved) - - -def _is_gcs(model_dir): - return model_dir and model_dir.startswith("gs://") diff --git a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index 545d7d8924c0c10544e6113e2968b7ae3d2090fc..d10927a0cdd5c67c8d2a8e569153235ee175ec4d 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -674,37 +674,11 @@ class ExperimentTest(test.TestCase): def test_min_eval_frequency_defaults(self): def dummy_model_fn(features, labels): # pylint: disable=unused-argument pass - - # The default value when model_dir is on GCS is 1000 - estimator = core_estimator.Estimator(dummy_model_fn, 'gs://dummy_bucket') - ex = experiment.Experiment( - estimator, train_input_fn=None, eval_input_fn=None) - self.assertEquals(ex._min_eval_frequency, 1000) - - # The default value when model_dir is not on GCS is 1 estimator = core_estimator.Estimator(dummy_model_fn, '/tmp/dummy') ex = experiment.Experiment( estimator, train_input_fn=None, eval_input_fn=None) self.assertEquals(ex._min_eval_frequency, 1) - # Make sure default not used when explicitly set - estimator = core_estimator.Estimator(dummy_model_fn, 'gs://dummy_bucket') - ex = experiment.Experiment( - estimator, - min_eval_frequency=123, - train_input_fn=None, - eval_input_fn=None) - self.assertEquals(ex._min_eval_frequency, 123) - - # Make sure default not used when explicitly set as 0 - estimator = core_estimator.Estimator(dummy_model_fn, 'gs://dummy_bucket') - ex = experiment.Experiment( - estimator, - min_eval_frequency=0, - train_input_fn=None, - eval_input_fn=None) - self.assertEquals(ex._min_eval_frequency, 0) - def test_continuous_train_and_eval(self): for est in self._estimators_for_tests(eval_dict={'global_step': 100}): if isinstance(est, core_estimator.Estimator): diff --git a/tensorflow/contrib/learn/python/learn/export_strategy.py b/tensorflow/contrib/learn/python/learn/export_strategy.py index 55a8b824312b89e0ac66513242191f4201ac212a..075cab536ecb5279e7e6f23abb0b70c75043a7ec 100644 --- a/tensorflow/contrib/learn/python/learn/export_strategy.py +++ b/tensorflow/contrib/learn/python/learn/export_strategy.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""ExportStrategy class represents different flavors of model export.""" +"""ExportStrategy class represents different flavors of model export (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,7 @@ from __future__ import print_function import collections from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated __all__ = ['ExportStrategy'] @@ -30,6 +36,10 @@ class ExportStrategy( ['name', 'export_fn', 'strip_default_attrs'])): """A class representing a type of model export. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Typically constructed by a utility function specific to the exporter, such as `saved_model_export_utils.make_export_strategy()`. @@ -56,6 +66,8 @@ class ExportStrategy( forward compatibility of the resulting `SavedModel`. """ + @deprecated(None, 'Please switch to tf.estimator.train_and_evaluate, and use ' + 'tf.estimator.Exporter.') def __new__(cls, name, export_fn, strip_default_attrs=None): return super(ExportStrategy, cls).__new__( cls, name, export_fn, strip_default_attrs) diff --git a/tensorflow/contrib/learn/python/learn/graph_actions.py b/tensorflow/contrib/learn/python/learn/graph_actions.py index 98365c05f663e5d2a06703457fc5663d7135f7d9..a997fab723a16dddf150aa9397863605e4e77933 100644 --- a/tensorflow/contrib/learn/python/learn/graph_actions.py +++ b/tensorflow/contrib/learn/python/learn/graph_actions.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""High level operations on graphs.""" +"""High level operations on graphs (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -68,6 +73,7 @@ def clear_summary_writers(): return summary_io.SummaryWriterCache.clear() +@deprecated(None, 'Use `SummaryWriterCache.get` directly.') def get_summary_writer(logdir): """Returns single SummaryWriter per logdir in current run. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py index 06c3782a471537cf3879450e6bd20899a35d96ac..8b133a4440d8cbc19abca64f972791fc16ade6f8 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/__init__.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/__init__.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tools to allow different io formats.""" +"""Tools to allow different io formats (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py b/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py index 7d666391cea3c0a52a2cb7e324c00d5f480710d5..e0a1948d95a727675dac8ff3ce9f55c35d5f8d8d 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/dask_io.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Methods to allow dask.DataFrame.""" +"""Methods to allow dask.DataFrame (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -21,6 +26,8 @@ from __future__ import print_function import numpy as np +from tensorflow.python.util.deprecation import deprecated + try: # pylint: disable=g-import-not-at-top import dask.dataframe as dd @@ -60,6 +67,7 @@ def _construct_dask_df_with_divisions(df): return dd.Series(merge(dsk, df.dask), name, df.name, divisions) +@deprecated(None, 'Please feed input to tf.data to support dask.') def extract_dask_data(data): """Extract data from dask.Series or dask.DataFrame for predictors. @@ -81,6 +89,7 @@ def extract_dask_data(data): return data +@deprecated(None, 'Please feed input to tf.data to support dask.') def extract_dask_labels(labels): """Extract data from dask.Series or dask.DataFrame for labels. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index 96be8b1bc402479d5611965f27abb197363cb939..c45b1d186471125776d6536112aebb66bb5ad558 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementations of different data feeders to provide data for TF trainer.""" +"""Implementations of different data feeders to provide data for TF trainer (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" # TODO(ipolosukhin): Replace this module with feed-dict queue runners & queues. @@ -31,6 +36,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated # pylint: disable=g-multiple-import,g-bad-import-order from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels @@ -101,6 +107,7 @@ def _is_iterable(x): return hasattr(x, 'next') or hasattr(x, '__next__') +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_train_data_feeder(x, y, n_classes, @@ -188,6 +195,7 @@ def _batch_data(x, batch_size=None): yield np.matrix(chunk) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_predict_data_feeder(x, batch_size=None): """Returns an iterable for feeding into predict step. @@ -219,6 +227,7 @@ def setup_predict_data_feeder(x, batch_size=None): return [x] +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def setup_processor_data_feeder(x): """Sets up processor iterable. @@ -233,6 +242,7 @@ def setup_processor_data_feeder(x): return x +@deprecated(None, 'Please convert numpy dtypes explicitly.') def check_array(array, dtype): """Checks array on dtype and converts it if different. @@ -275,8 +285,14 @@ def _check_dtype(dtype): class DataFeeder(object): - """Data feeder is an example class to sample data for TF trainer.""" + """Data feeder is an example class to sample data for TF trainer. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, x, y, @@ -563,6 +579,10 @@ class DataFeeder(object): class StreamingDataFeeder(DataFeeder): """Data feeder for TF trainer that reads data from iterator. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Streaming data feeder allows to read data as it comes it from disk or somewhere else. It's custom to have this iterators rotate infinetly over the dataset, to allow control of how much to learn on the trainer side. @@ -771,11 +791,16 @@ class StreamingDataFeeder(DataFeeder): class DaskDataFeeder(object): """Data feeder for that reads data from dask.Series and dask.DataFrame. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Numpy arrays can be serialized to disk and it's possible to do random seeks into them. DaskDataFeeder will remove requirement to have full dataset in the memory and still do random seeks for sampling of batches. """ + @deprecated(None, 'Please feed input to tf.data to support dask.') def __init__(self, x, y, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py index 884faf8335e2a3ca1d27d2d93b4c817131648774..f8aaa0c9e3e5b589a6ad47678dba3dc38de7c471 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/generator_io.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to allow generator of dict with numpy arrays.""" +"""Methods to allow generator of dict with numpy arrays (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,8 +28,10 @@ from types import FunctionType from types import GeneratorType from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue_data as enqueue_data +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Please use tf.data.') def generator_input_fn(x, target_key=None, batch_size=128, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 3a46c239688017f9204d2c6182a6f81cd325a417..9e816f54b6cf8dee84c6d62406ab3db700054d06 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to read data in the graph.""" +"""Methods to read data in the graph (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -34,11 +39,13 @@ from tensorflow.python.platform import gfile from tensorflow.python.summary import summary from tensorflow.python.training import input as input_ops from tensorflow.python.training import queue_runner +from tensorflow.python.util.deprecation import deprecated # Default name for key in the feature dict. KEY_FEATURE_NAME = '__key__' +@deprecated(None, 'Use tf.data.') def read_batch_examples(file_pattern, batch_size, reader, @@ -106,6 +113,7 @@ def read_batch_examples(file_pattern, return examples +@deprecated(None, 'Use tf.data.') def read_keyed_batch_examples(file_pattern, batch_size, reader, @@ -175,6 +183,7 @@ def read_keyed_batch_examples(file_pattern, seed=seed) +@deprecated(None, 'Use tf.data.') def read_keyed_batch_examples_shared_queue(file_pattern, batch_size, reader, @@ -452,6 +461,7 @@ def _read_keyed_batch_examples_helper(file_pattern, return queued_examples_with_keys +@deprecated(None, 'Use tf.data.') def read_keyed_batch_features(file_pattern, batch_size, features, @@ -540,6 +550,7 @@ def read_keyed_batch_features(file_pattern, name=scope) +@deprecated(None, 'Use tf.data.') def read_keyed_batch_features_shared_queue(file_pattern, batch_size, features, @@ -620,6 +631,7 @@ def read_keyed_batch_features_shared_queue(file_pattern, name=scope) +@deprecated(None, 'Use tf.data.') def queue_parsed_features(parsed_features, keys=None, feature_queue_capacity=100, @@ -742,6 +754,7 @@ def queue_parsed_features(parsed_features, return dequeued_keys, dequeued_parsed_features +@deprecated(None, 'Use tf.data.') def read_batch_features(file_pattern, batch_size, features, @@ -821,6 +834,7 @@ def read_batch_features(file_pattern, return features +@deprecated(None, 'Use tf.data.') def read_batch_record_features(file_pattern, batch_size, features, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py index 692438807fbd7febb156d4db73b5d3deba6c987d..29552d24f1eaa0d85a99c8b09f69d007e7e4fe9f 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io.py @@ -12,15 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Methods to allow dict of numpy arrays.""" +"""Methods to allow dict of numpy arrays (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn as core_numpy_input_fn +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Use tf.estimator.inputs.numpy_input_fn.') def numpy_input_fn(x, y=None, batch_size=128, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py index ede7558eafa9237dc63aa95a62e599c5e9755822..b4ef055f5ae484ec704ad42efcf2c00c4a7a4f56 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/pandas_io.py @@ -13,13 +13,19 @@ # limitations under the License. # ============================================================================== -"""Methods to allow pandas.DataFrame.""" +"""Methods to allow pandas.DataFrame (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.python.estimator.inputs.pandas_io import pandas_input_fn as core_pandas_input_fn +from tensorflow.python.util.deprecation import deprecated try: # pylint: disable=g-import-not-at-top @@ -47,6 +53,7 @@ PANDAS_DTYPES = { } +@deprecated(None, 'Please use tf.estimator.inputs.pandas_input_fn') def pandas_input_fn(x, y=None, batch_size=128, @@ -66,6 +73,7 @@ def pandas_input_fn(x, target_column=target_column) +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_data(data): """Extract data from pandas.DataFrame for predictors. @@ -96,6 +104,7 @@ def extract_pandas_data(data): 'float, or bool. Found: ' + ', '.join(error_report)) +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_matrix(data): """Extracts numpy matrix from pandas DataFrame. @@ -111,6 +120,7 @@ def extract_pandas_matrix(data): return data.as_matrix() +@deprecated(None, 'Please access pandas data directly.') def extract_pandas_labels(labels): """Extract data from pandas.DataFrame for labels. diff --git a/tensorflow/contrib/learn/python/learn/learn_runner.py b/tensorflow/contrib/learn/python/learn/learn_runner.py index 2af723a0d64822e81fa0fbeb106ab812de6ab4e8..d719a3e488b9905ef7903e21d90dbaae0449735c 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Runs an Experiment.""" +"""Runs an Experiment (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,7 @@ from tensorflow.contrib.learn.python.learn.estimators import run_config as run_c from tensorflow.contrib.learn.python.learn.experiment import Experiment from tensorflow.contrib.training.python.training import hparam as hparam_lib from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.deprecation import deprecated # TODO(xiejw): Refactor the learn_runner to make code reusable. @@ -99,6 +105,7 @@ def _wrapped_experiment_fn_with_uid_check(experiment_fn, require_hparams=False): return wrapped_experiment_fn +@deprecated(None, 'Use tf.estimator.train_and_evaluate.') def run(experiment_fn, output_dir=None, schedule=None, run_config=None, hparams=None): """Make and run an experiment. @@ -218,6 +225,7 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None, return _execute_schedule(experiment, schedule) +@deprecated(None, 'Use tf.estimator.train_and_evaluate.') def tune(experiment_fn, tuner): """Tune an experiment with hyper-parameters. diff --git a/tensorflow/contrib/learn/python/learn/learn_runner_lib.py b/tensorflow/contrib/learn/python/learn/learn_runner_lib.py index 7d9b1c7716f0ab1f2274ca53406175240b613027..ba2d067787c1dfd4e4820ecc916f1053e9f3cf60 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner_lib.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner_lib.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities to run and tune an Experiment. +"""Utilities to run and tune an Experiment (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. @@run @@tune diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index 6440bc204b8e339ff51311dcc87b36f556b94092..97220365d5dddb82b602369f06bea021a86d584f 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The metric spec class to flexibly connect models and metrics.""" +"""The metric spec class to flexibly connect models and metrics (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,7 @@ import six from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect +from tensorflow.python.util.deprecation import deprecated def _assert_named_args(sentinel): @@ -223,6 +229,10 @@ def _adapt_metric_fn( class MetricSpec(object): """MetricSpec connects a model to metric functions. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + The MetricSpec class contains all information necessary to connect the output of a `model_fn` to the metrics (usually, streaming metrics) that are used in evaluation. @@ -284,6 +294,7 @@ class MetricSpec(object): """ + @deprecated(None, 'Use tf.estimator.EstimatorSpec.eval_metric_ops.') def __init__(self, metric_fn, prediction_key=None, diff --git a/tensorflow/contrib/learn/python/learn/models.py b/tensorflow/contrib/learn/python/learn/models.py index 4283240d018c949bb35aeb12032d2ee8b75884a5..bd4bbf9f8c9ad7e8a0fc06d8c0dc24672536c158 100644 --- a/tensorflow/contrib/learn/python/learn/models.py +++ b/tensorflow/contrib/learn/python/learn/models.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Various high level TF models.""" +"""Various high level TF models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -28,8 +33,10 @@ from tensorflow.python.ops import array_ops as array_ops_ from tensorflow.python.ops import init_ops from tensorflow.python.ops import variable_scope as vs from tensorflow.python.summary import summary +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Consider using a tf.estimator.LinearRegressor') def linear_regression_zero_init(x, y): """Linear regression subgraph with zero-value initial weights and bias. @@ -43,6 +50,7 @@ def linear_regression_zero_init(x, y): return linear_regression(x, y, init_mean=0.0, init_stddev=0.0) +@deprecated(None, 'Consider using a class from tf.estimator.LinearClassifier') def logistic_regression_zero_init(x, y): """Logistic regression subgraph with zero-value initial weights and bias. @@ -56,6 +64,7 @@ def logistic_regression_zero_init(x, y): return logistic_regression(x, y, init_mean=0.0, init_stddev=0.0) +@deprecated(None, 'Consider using a class from tf.estimator.') def linear_regression(x, y, init_mean=None, init_stddev=1.0): """Creates linear regression TensorFlow subgraph. @@ -107,6 +116,7 @@ def linear_regression(x, y, init_mean=None, init_stddev=1.0): return losses_ops.mean_squared_error_regressor(x, y, weights, bias) +@deprecated(None, 'Consider using a class from tf.estimator.') def logistic_regression(x, y, class_weight=None, @@ -203,6 +213,7 @@ def _reverse_seq(input_seq, lengths): return result +@deprecated(None, 'Please consider `tf.nn.bidirectional_dynamic_rnn`.') def bidirectional_rnn(cell_fw, cell_bw, inputs, @@ -283,6 +294,7 @@ def bidirectional_rnn(cell_fw, # End of TensorFlow 0.7 +@deprecated(None, 'Please consider tensorflow/tensor2tensor.') def get_rnn_model(rnn_size, cell_type, num_layers, input_op_fn, bidirectional, target_predictor_fn, sequence_length, initial_state, attn_length, attn_size, attn_vec_size): diff --git a/tensorflow/contrib/learn/python/learn/monitored_session.py b/tensorflow/contrib/learn/python/learn/monitored_session.py index 22602e9f69d972505d83a66a6f9183b5e4d15c44..ac0433f1775feeed2ec3cf49291da01500bef01b 100644 --- a/tensorflow/contrib/learn/python/learn/monitored_session.py +++ b/tensorflow/contrib/learn/python/learn/monitored_session.py @@ -13,7 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A wrapper of Session API which runs hooks.""" +"""A wrapper of Session API which runs hooks (deprecated). + +These are deprecated aliases for classes and functions in `tf.train`. Please use +those directly. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 51381a7427c919592b8e818c4b46dba974992610..77f7c73d5412d40b338eaff4cf04d99fd0892723 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Monitors instrument the training process. +"""Monitors instrument the training process (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. @@get_default_monitors @@BaseMonitor @@ -59,6 +63,10 @@ from tensorflow.python.util import tf_inspect class BaseMonitor(object): """Base class for Monitors. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Defines basic interfaces of Monitors. Monitors can either be run on all workers or, more commonly, restricted to run exclusively on the elected chief worker. @@ -229,6 +237,10 @@ def _extract_output(outputs, request): class EveryN(BaseMonitor): """Base class for monitors that execute callbacks every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This class adds three new callbacks: - every_n_step_begin - every_n_step_end @@ -418,6 +430,10 @@ class StopAtStep(BaseMonitor): class PrintTensor(EveryN): """Prints given tensors every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This is an `EveryN` monitor and has consistent semantic for `every_n` and `first_n`. @@ -455,9 +471,12 @@ class PrintTensor(EveryN): class LoggingTrainable(EveryN): """Writes trainable variable values into log every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Write the tensors in trainable variables `every_n` steps, starting with the `first_n`th step. - """ def __init__(self, scope=None, every_n=100, first_n=1): @@ -493,7 +512,12 @@ class LoggingTrainable(EveryN): class SummarySaver(EveryN): - """Saves summaries every N steps.""" + """Saves summaries every N steps. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, summary_op, @@ -554,6 +578,10 @@ class SummarySaver(EveryN): class ValidationMonitor(EveryN): """Runs evaluation of a given estimator, at most every N steps. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note that the evaluation is done based on the saved checkpoint, which will usually be older than the current step. @@ -573,7 +601,8 @@ class ValidationMonitor(EveryN): early_stopping_rounds=None, early_stopping_metric="loss", early_stopping_metric_minimize=True, - name=None): + name=None, + check_interval_secs=5): """Initializes a ValidationMonitor. Args: @@ -600,6 +629,9 @@ class ValidationMonitor(EveryN): loss metrics like mean squared error, and False for performance metrics like accuracy. name: See `BaseEstimator.evaluate`. + check_interval_secs: Only check for new checkpoint if at least + `check_interval_secs` have passed. Ignore if None. Default is 5 secs. + Raises: ValueError: If both x and input_fn are provided. @@ -626,6 +658,8 @@ class ValidationMonitor(EveryN): self._early_stopped = False self._latest_path = None self._latest_path_step = None + self._last_checkpoint_check_time = None + self._check_interval_secs = check_interval_secs @property def early_stopped(self): @@ -690,6 +724,16 @@ class ValidationMonitor(EveryN): # that's what is being evaluated. if self._estimator is None: raise ValueError("Missing call to set_estimator.") + current_time = time.time() + if (self._check_interval_secs is not None and + self._last_checkpoint_check_time is not None and + current_time - self._last_checkpoint_check_time <= + self._check_interval_secs): + logging.debug( + "Skipping evaluation since less than %d seconds have passed since " + "last check for a new checkpoint.", self._check_interval_secs) + return False + self._last_checkpoint_check_time = current_time # Check that we are not running evaluation on the same checkpoint. latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir) if latest_path is None: @@ -740,6 +784,10 @@ class ValidationMonitor(EveryN): class CaptureVariable(EveryN): """Captures a variable's values into a collection. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + This monitor is useful for unit testing. You should exercise caution when using this monitor in production, since it never discards values. @@ -778,6 +826,7 @@ class CaptureVariable(EveryN): self._var_values[step] = _extract_output(outputs, self._var_name) +@deprecation.deprecated(None, "Use tf.train.MonitoredTrainingSession.") def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100, @@ -812,6 +861,10 @@ def get_default_monitors(loss_op=None, class GraphDump(BaseMonitor): """Dumps almost all tensors in the graph at every step. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Note, this is very expensive, prefer `PrintTensor` in production. """ @@ -901,7 +954,12 @@ class GraphDump(BaseMonitor): class ExportMonitor(EveryN): - """Monitor that exports Estimator every N steps.""" + """Monitor that exports Estimator every N steps. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ @deprecation.deprecated("2017-03-25", "ExportMonitor is deprecated. Please pass an " @@ -1024,7 +1082,12 @@ class ExportMonitor(EveryN): class CheckpointSaver(BaseMonitor): - """Saves checkpoints every N steps or N seconds.""" + """Saves checkpoints every N steps or N seconds. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, checkpoint_dir, @@ -1109,7 +1172,12 @@ class CheckpointSaver(BaseMonitor): class StepCounter(EveryN): - """Steps per second monitor.""" + """Steps per second monitor. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None): super(StepCounter, self).__init__(every_n_steps=every_n_steps) @@ -1149,6 +1217,10 @@ class NanLossDuringTrainingError(RuntimeError): class NanLoss(EveryN): """NaN Loss monitor. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Monitors loss and stops training if loss is NaN. Can either fail with exception or just stop training. """ diff --git a/tensorflow/contrib/learn/python/learn/monitors_test.py b/tensorflow/contrib/learn/python/learn/monitors_test.py index b2b24776c60183113a5f936dd276ff312d6d0079..5c34d0ddb01f3bcdc407e6926e7c5b73be1863b4 100644 --- a/tensorflow/contrib/learn/python/learn/monitors_test.py +++ b/tensorflow/contrib/learn/python/learn/monitors_test.py @@ -385,7 +385,11 @@ class MonitorsTest(test.TestCase): estimator.evaluate.return_value = validation_outputs monitor = learn.monitors.ValidationMonitor( - x=constant_op.constant(2.0), every_n_steps=0, early_stopping_rounds=2) + x=constant_op.constant(2.0), + every_n_steps=0, + early_stopping_rounds=2, + check_interval_secs=None) + self._assert_validation_monitor(monitor) monitor.set_estimator(estimator) with ops.Graph().as_default() as g, self.test_session(g): diff --git a/tensorflow/contrib/learn/python/learn/ops/__init__.py b/tensorflow/contrib/learn/python/learn/ops/__init__.py index 33962e34cc685ce2c830a7bbfd1b5c626bcd8b31..efb1f47cf5bb2dcd0fb37b7b85cd8f170d56e4d1 100644 --- a/tensorflow/contrib/learn/python/learn/ops/__init__.py +++ b/tensorflow/contrib/learn/python/learn/ops/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Various TensorFlow Ops.""" +"""Various TensorFlow Ops (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py b/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py index fa3b7323e343371e986b763d30a8a44620894549..b3b067b8e1a4eb9f644e8e55587b3405d91a0189 100644 --- a/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/embeddings_ops.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops to work with embeddings. +"""TensorFlow Ops to work with embeddings (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Note: categorical variables are handled via embeddings in many cases. For example, in case of words. diff --git a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py index b040ab3bb6c516158589a8e30d56fff1f7728951..92976d1539c7ddc226b81f903beee82b798ec8db 100644 --- a/tensorflow/contrib/learn/python/learn/ops/losses_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/losses_ops.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops for loss computation.""" +"""TensorFlow Ops for loss computation (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py index 45727faab4362abeab18f77861353eb53976023a..aa37cb4a76e2a6157bf077d327248353bd516472 100644 --- a/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py +++ b/tensorflow/contrib/learn/python/learn/ops/seq2seq_ops.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Ops for Sequence to Sequence models.""" +"""TensorFlow Ops for Sequence to Sequence models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -26,8 +31,10 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.util.deprecation import deprecated +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def sequence_classifier(decoding, labels, sampling_decoding=None, name=None): """Returns predictions and loss for sequence of predictions. @@ -57,6 +64,7 @@ def sequence_classifier(decoding, labels, sampling_decoding=None, name=None): return array_ops.stack(predictions, axis=1), loss +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None): """Processes inputs for Sequence to Sequence models. @@ -87,6 +95,7 @@ def seq2seq_inputs(x, y, input_length, output_length, sentinel=None, name=None): return in_x, in_y, out_y +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def rnn_decoder(decoder_inputs, initial_state, cell, scope=None): """RNN Decoder that creates training and sampling sub-graphs. @@ -123,6 +132,7 @@ def rnn_decoder(decoder_inputs, initial_state, cell, scope=None): return outputs, states, sampling_outputs, sampling_states +@deprecated(None, 'Please use tf.nn/tf.layers directly.') def rnn_seq2seq(encoder_inputs, decoder_inputs, encoder_cell, diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py b/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py index 7bcc177d4ea0ab57f092d68888a72de2b2fd5edc..e8c6e1acf80f0791421bee59aff30e67bccb44b2 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Preprocessing tools useful for building models.""" +"""Preprocessing tools useful for building models (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py b/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py index 154739d497ec1029026eaca1e93b37cd225f1050..faba3b2025e8abb51d1989c3fafbd5e711d6559b 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/categorical.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Implements preprocessing transformers for categorical variables.""" +"""Implements preprocessing transformers for categorical variables (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -22,6 +27,8 @@ from __future__ import print_function import math import numpy as np +from tensorflow.python.util.deprecation import deprecated + # pylint: disable=g-bad-import-order from . import categorical_vocabulary from ..learn_io.data_feeder import setup_processor_data_feeder @@ -31,10 +38,16 @@ from ..learn_io.data_feeder import setup_processor_data_feeder class CategoricalProcessor(object): """Maps documents to sequences of word ids. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + As a common convention, Nan values are handled as unknown tokens. Both float('nan') and np.nan are accepted. """ + @deprecated(None, 'Please use tensorflow/transform or tf.data for sequence ' + 'processing.') def __init__(self, min_frequency=0, share=False, vocabularies=None): """Initializes a CategoricalProcessor instance. diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py b/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py index 5709955c49fba50ca4a299a443a2902bbd9c6b23..3ac370a6ab4423846e810900514445ad5269b680 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/categorical_vocabulary.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -"""Categorical vocabulary classes to map categories to indexes. +"""Categorical vocabulary classes to map categories to indexes (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Can be used for categorical variables, sparse variables and words. """ @@ -25,14 +29,21 @@ from __future__ import print_function import collections import six +from tensorflow.python.util.deprecation import deprecated + class CategoricalVocabulary(object): """Categorical variables vocabulary class. + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + Accumulates and provides mapping from classes to indexes. Can be easily used for words. """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, unknown_token="", support_reverse=True): self._unknown_token = unknown_token self._mapping = {unknown_token: 0} diff --git a/tensorflow/contrib/learn/python/learn/preprocessing/text.py b/tensorflow/contrib/learn/python/learn/preprocessing/text.py index 3af2074c2a46f0258c04111fff0235ba8309625e..f2b6776be7789a9433bfe41eb9354b74347059ec 100644 --- a/tensorflow/contrib/learn/python/learn/preprocessing/text.py +++ b/tensorflow/contrib/learn/python/learn/preprocessing/text.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""Implements a number of text preprocessing utilities.""" +"""Implements a number of text preprocessing utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -24,6 +29,7 @@ import numpy as np import six from tensorflow.python.platform import gfile +from tensorflow.python.util.deprecation import deprecated from .categorical_vocabulary import CategoricalVocabulary # pylint: disable=g-bad-import-order @@ -38,6 +44,7 @@ TOKENIZER_RE = re.compile(r"[A-Z]{2,}(?![a-z])|[A-Z][a-z]+(?=[A-Z])|[\'\w\-]+", re.UNICODE) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') def tokenizer(iterator): """Tokenizer generator. @@ -51,9 +58,16 @@ def tokenizer(iterator): yield TOKENIZER_RE.findall(value) +@deprecated(None, 'Please use tensorflow/transform or tf.data.') class ByteProcessor(object): - """Maps documents into sequence of ids for bytes.""" + """Maps documents into sequence of ids for bytes. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, max_document_length): self.max_document_length = max_document_length @@ -108,8 +122,14 @@ class ByteProcessor(object): class VocabularyProcessor(object): - """Maps documents to sequences of word ids.""" + """Maps documents to sequences of word ids. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Please use tensorflow/transform or tf.data.') def __init__(self, max_document_length, min_frequency=0, diff --git a/tensorflow/contrib/learn/python/learn/session_run_hook.py b/tensorflow/contrib/learn/python/learn/session_run_hook.py index a8ba2be97206f2b974d256ad2c62c21a4e3e55d8..87edc9b720bdb3edcd5f2dcd1662d14da53c51cf 100644 --- a/tensorflow/contrib/learn/python/learn/session_run_hook.py +++ b/tensorflow/contrib/learn/python/learn/session_run_hook.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""This file is deprecated. Use tensorflow.python.training.session_run_hook.""" +"""This file is deprecated. Use `tensorflow.python.training.session_run_hook`. + +See [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/summary_writer_cache.py b/tensorflow/contrib/learn/python/learn/summary_writer_cache.py index 919d415c302b8ec17202aad34ff0bee69bfee2c7..d663cf5fb79c428b0e70d66b0f1305f0559a05c9 100644 --- a/tensorflow/contrib/learn/python/learn/summary_writer_cache.py +++ b/tensorflow/contrib/learn/python/learn/summary_writer_cache.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Wrapper for a Session-like object that handles threads and recovery. +"""Wrapper for a Session-like object that handles threads and recovery (deprecated). + +These are deprecated aliases for classes and functions in `tf.train`. Please use +those directly. Based on an original design of Illia Polosukhin. """ diff --git a/tensorflow/contrib/learn/python/learn/trainable.py b/tensorflow/contrib/learn/python/learn/trainable.py index 429b6040be21d8cbe1f2bba58090366552fdfbe7..a1a3f20dcd8cb5ff7baa559ac41d5e5c40780511 100644 --- a/tensorflow/contrib/learn/python/learn/trainable.py +++ b/tensorflow/contrib/learn/python/learn/trainable.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`Trainable` interface.""" +"""`Trainable` interface (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division @@ -23,6 +28,8 @@ import abc class Trainable(object): """Interface for objects that are trainable by, e.g., `Experiment`. + + THIS CLASS IS DEPRECATED. """ __metaclass__ = abc.ABCMeta diff --git a/tensorflow/contrib/learn/python/learn/utils/__init__.py b/tensorflow/contrib/learn/python/learn/utils/__init__.py index 48978d0ac34cec2b18e6794dcf3b260bc3b683c4..66d8dc6fd43b383919a16515bc96be492a253bf6 100644 --- a/tensorflow/contrib/learn/python/learn/utils/__init__.py +++ b/tensorflow/contrib/learn/python/learn/utils/__init__.py @@ -13,7 +13,12 @@ # limitations under the License. # ============================================================================== -"""TensorFlow Learn Utils.""" +"""TensorFlow Learn Utils (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index cb34cb1d26b6812c7f3f39e9f965615de5a8ef07..3eacac7a3d3dcff4d39025fdee88e16e385b1b84 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -13,14 +13,18 @@ # limitations under the License. # ============================================================================== -"""Export utilities.""" +"""Export utilities (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.contrib.framework import deprecated -from tensorflow.python.training import training_util from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.python.client import session as tf_session @@ -32,6 +36,7 @@ from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver +from tensorflow.python.training import training_util @deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.') diff --git a/tensorflow/contrib/learn/python/learn/utils/gc.py b/tensorflow/contrib/learn/python/learn/utils/gc.py index 226915987a4934626066b12810f579ae675107b2..916aecbea88b10bbef316ffb89d4c4d89667cb29 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc.py @@ -13,7 +13,11 @@ # limitations under the License. # ============================================================================== -r"""System for specifying garbage collection (GC) of path based data. +r"""System for specifying garbage collection (GC) of path based data (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. This framework allows for GC of data specified by path names, for example files on disk. gc.Path objects each represent a single item stored at a path and may @@ -73,10 +77,12 @@ import os from tensorflow.python.platform import gfile from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated Path = collections.namedtuple('Path', 'path export_version') +@deprecated(None, 'Please implement your own file management or use Saver.') def largest_export_versions(n): """Creates a filter that keeps the largest n export versions. @@ -97,6 +103,7 @@ def largest_export_versions(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def one_of_every_n_export_versions(n): """Creates a filter that keeps one of every n export versions. @@ -128,6 +135,7 @@ def one_of_every_n_export_versions(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def mod_export_version(n): """Creates a filter that keeps every export that is a multiple of n. @@ -146,6 +154,7 @@ def mod_export_version(n): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def union(lf, rf): """Creates a filter that keeps the union of two filters. @@ -163,6 +172,7 @@ def union(lf, rf): return keep +@deprecated(None, 'Please implement your own file management or use Saver.') def negation(f): """Negate a filter. @@ -179,6 +189,7 @@ def negation(f): return keep +@deprecated(None, 'Please implement your own file name management.') def get_paths(base_dir, parser): """Gets a list of Paths in a given directory. diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py index b2521933e524e7ec24d73d4b5171f33e507dd88c..b92eb9fea8b7ccea56c781df74dcfa1cc5508e48 100644 --- a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities for creating input_fns. +"""Utilities for creating input_fns (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Contents of this file are moved to tensorflow/python/estimator/export.py. InputFnOps is renamed to ServingInputReceiver. @@ -32,13 +36,17 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import parsing_ops +from tensorflow.python.util.deprecation import deprecated class InputFnOps(collections.namedtuple('InputFnOps', ['features', 'labels', 'default_inputs'])): - """A return type for an input_fn. + """A return type for an input_fn (deprecated). + + THIS CLASS IS DEPRECATED. Please use tf.estimator.export.ServingInputReceiver + instead. This return type is currently only supported for serving input_fn. Training and eval input_fn should return a `(features, labels)` tuple. @@ -56,6 +64,8 @@ class InputFnOps(collections.namedtuple('InputFnOps', """ +@deprecated(None, 'Please use ' + 'tf.estimator.export.build_parsing_serving_input_receiver_fn.') def build_parsing_serving_input_fn(feature_spec, default_batch_size=None): """Build an input_fn appropriate for serving, expecting fed tf.Examples. @@ -84,6 +94,8 @@ def build_parsing_serving_input_fn(feature_spec, default_batch_size=None): return input_fn +@deprecated(None, 'Please use ' + 'tf.estimator.export.build_raw_serving_input_receiver_fn.') def build_default_serving_input_fn(features, default_batch_size=None): """Build an input_fn appropriate for serving, expecting feature Tensors. diff --git a/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py b/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py index 6a63fb545a56e6040b0b0c3bbb6a17cd96925895..6dbaa15f8391b0044be8e30ca191753beb88db93 100644 --- a/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py +++ b/tensorflow/contrib/learn/python/learn/utils/inspect_checkpoint.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A simple script for inspect checkpoint files.""" +"""A simple script for inspect checkpoint files (deprecated).""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 1593380007b2799fb1d17e92408ab19a7b47fe1e..213619a1877d898dc7c55f6b8c340df5c1afbf27 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities supporting export to SavedModel. +"""Utilities supporting export to SavedModel (deprecated). + +This module and all its submodules are deprecated. See +[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) +for migration instructions. Some contents of this file are moved to tensorflow/python/estimator/export.py: @@ -52,8 +56,9 @@ from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.summary import summary_iterator from tensorflow.python.training import saver - from tensorflow.python.util import compat +from tensorflow.python.util.deprecation import deprecated + # A key for use in the input_alternatives dict indicating the default input. # This is the input that will be expected when a serving request does not @@ -77,6 +82,7 @@ FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative' _FALLBACK_DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative' +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def build_standardized_signature_def(input_tensors, output_tensors, problem_type): """Build a SignatureDef using problem type and input and output Tensors. @@ -156,6 +162,7 @@ def _is_regression_problem(problem_type, input_tensors, output_tensors): len(input_tensors) == 1 and len(output_tensors) == 1) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_input_alternatives(input_ops): """Obtain all input alternatives using the input_fn output and heuristics.""" input_alternatives = {} @@ -181,6 +188,7 @@ def get_input_alternatives(input_ops): return input_alternatives, features +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_output_alternatives(model_fn_ops, default_output_alternative_key=None): """Obtain all output alternatives using the model_fn output and heuristics. @@ -246,6 +254,7 @@ def get_output_alternatives(model_fn_ops, default_output_alternative_key=None): sorted(output_alternatives.keys()))) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def build_all_signature_defs(input_alternatives, output_alternatives, actual_default_output_alternative_key): """Build `SignatureDef`s from all pairs of input and output alternatives.""" @@ -279,6 +288,7 @@ def build_all_signature_defs(input_alternatives, output_alternatives, MAX_DIRECTORY_CREATION_ATTEMPTS = 10 +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_timestamped_export_dir(export_dir_base): """Builds a path to a new subdirectory within the base directory. @@ -317,6 +327,7 @@ def get_timestamped_export_dir(export_dir_base): '{} attempts.'.format(MAX_DIRECTORY_CREATION_ATTEMPTS)) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_temp_export_dir(timestamped_export_dir): """Builds a directory name based on the argument but starting with 'temp-'. @@ -344,6 +355,7 @@ def _export_version_parser(path): return path._replace(export_version=int(filename)) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def get_most_recent_export(export_dir_base): """Locate the most recent SavedModel export in a directory of many exports. @@ -363,6 +375,7 @@ def get_most_recent_export(export_dir_base): return next(iter(results or []), None) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def garbage_collect_exports(export_dir_base, exports_to_keep): """Deletes older exports, retaining only a given number of the most recent. @@ -387,6 +400,7 @@ def garbage_collect_exports(export_dir_base, exports_to_keep): logging.warn('Can not delete %s recursively: %s', p.path, e) +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def make_export_strategy(serving_input_fn, default_output_alternative_key=None, assets_extra=None, @@ -469,6 +483,8 @@ def make_export_strategy(serving_input_fn, return export_strategy.ExportStrategy('Servo', export_fn, strip_default_attrs) +@deprecated(None, + 'Use tf.estimator.export.build_parsing_serving_input_receiver_fn') def make_parsing_export_strategy(feature_columns, default_output_alternative_key=None, assets_extra=None, @@ -555,8 +571,14 @@ def _default_compare_fn(curr_best_eval_result, cand_eval_result): class BestModelSelector(object): - """A helper that keeps track of export selection candidates.""" + """A helper that keeps track of export selection candidates. + + THIS CLASS IS DEPRECATED. See + [contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md) + for general migration instructions. + """ + @deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def __init__(self, event_file_pattern=None, compare_fn=None): """Constructor of this class. @@ -622,6 +644,7 @@ class BestModelSelector(object): return best_eval_result +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def make_best_model_export_strategy( serving_input_fn, exports_to_keep=1, @@ -707,6 +730,7 @@ def make_best_model_export_strategy( # TODO(b/67013778): Revisit this approach when corresponding changes to # TF Core are finalized. +@deprecated(None, 'Switch to tf.estimator.Exporter and associated utilities.') def extend_export_strategy(base_export_strategy, post_export_fn, post_export_name=None): diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 3520a4eaf0e6647bd1357a5bc34140e8f87e1215..44c4a7e2ca8d019ca602c7f2b492cd1e70b17561 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -10,6 +10,7 @@ exports_files(["LICENSE"]) exports_files(glob([ "testdata/*.bin", + "testdata/*.pb", "models/testdata/*", ])) diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index 3e55d2a496c1d83ec0501df27deee4e19a5012a7..00e93d2c4f3ab27057b855fba6fccf2ec8d7a1c1 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -6,7 +6,7 @@ TensorFlow Lite uses many techniques for achieving low latency like optimizing t ![image](g3doc/TFLite-Architecture.jpg) # Getting Started with an Android Demo App -This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo. +This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using either a quantized Mobilenet model or a floating point Inception-v3 model. A device running Android 5.0 ( API 21) or higher is required to run the demo. There are 3 ways to get the demo app to your device - Download the prebuilt binary or @@ -29,9 +29,16 @@ The simplest way to compile the demo app, and try out changes to the project cod - Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings). - Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project. - Click through installing all the Gradle extensions it requests. - - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) - - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: - `tensorflow/contrib/lite/java/demo/app/src/main/assets/` + - Either + - Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip) + - unzip and copy mobilenet_quant_v1_224.tflite to the assets directory: + `tensorflow/contrib/lite/java/demo/app/src/main/assets/` + - Or download the floating point Inception-v3 model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip) + - unzip and copy inceptionv3_non_slim_2015.tflite to the assets directory + - change the chosen classifier in [Camera2BasicFragment.java](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java) from + `classifier = new ImageClassifierQuantizedMobileNet(getActivity());` + to + `classifier = new ImageClassifierFloatInception(getActivity());` - Build and run the demo app ## Building TensorFlow Lite and the demo app from source @@ -84,7 +91,7 @@ Currently, we only support building the Android demo app within a Python 2 environment (due to a Bazel bug). ### More about the demo -The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app. +The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used (229 * 229 for Inception-v3). The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch. 224 * 224 (299 * 299) is the width and height of the image. 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. Both models have 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The model file must be downloaded and bundled within the assets directory of the app. # iOS Demo App diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc index 87b17c338e7afc33d32dd9688cc0825ac319dd19..8e47e2375e2e306c345a2b6caa2411abd9b3ceb0 100644 --- a/tensorflow/contrib/lite/arena_planner.cc +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -128,6 +128,11 @@ TfLiteStatus ArenaPlanner::PlanAllocations() { } TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) { + // Grow the size of `allocs_` if necessary. This allows allocating temporary + // tensors in op's `prepare` function. + TF_LITE_ENSURE(context_, graph_info_->num_tensors() >= allocs_.size()); + allocs_.resize(graph_info_->num_tensors()); + TF_LITE_ENSURE_STATUS(CalculateAllocations(first_node, last_node)); TF_LITE_ENSURE_STATUS(Commit()); diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 5dbeadd16582ec586adab100b8a46e10182bd5ee..5fc8954743e5b3b458e5c2004f4378cbad6056c0 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -195,6 +195,10 @@ typedef struct { bool keep_dims; } TfLiteMeanParams; +typedef struct { + int num_splits; +} TfLiteSplitParams; + typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..88cdf1d46312f1e610825f23f3d8d357b0762bac --- /dev/null +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -0,0 +1,86 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ + +// DO NOT EDIT MANUALLY: This file is automatically generated by +// `schema_builtin_ops_header_generator.py`. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// The enum for builtin operators. +// Note: CUSTOM and DELEGATE are 2 special ops which are not real biultin +// ops. +typedef enum { + kTfLiteBuiltinAdd = 0, + kTfLiteBuiltinAveragePool2d = 1, + kTfLiteBuiltinConcatenation = 2, + kTfLiteBuiltinConv2d = 3, + kTfLiteBuiltinDepthwiseConv2d = 4, + kTfLiteBuiltinEmbeddingLookup = 7, + kTfLiteBuiltinFullyConnected = 9, + kTfLiteBuiltinHashtableLookup = 10, + kTfLiteBuiltinL2Normalization = 11, + kTfLiteBuiltinL2Pool2d = 12, + kTfLiteBuiltinLocalResponseNormalization = 13, + kTfLiteBuiltinLogistic = 14, + kTfLiteBuiltinLshProjection = 15, + kTfLiteBuiltinLstm = 16, + kTfLiteBuiltinMaxPool2d = 17, + kTfLiteBuiltinMul = 18, + kTfLiteBuiltinRelu = 19, + kTfLiteBuiltinReluN1To1 = 20, + kTfLiteBuiltinRelu6 = 21, + kTfLiteBuiltinReshape = 22, + kTfLiteBuiltinResizeBilinear = 23, + kTfLiteBuiltinRnn = 24, + kTfLiteBuiltinSoftmax = 25, + kTfLiteBuiltinSpaceToDepth = 26, + kTfLiteBuiltinSvdf = 27, + kTfLiteBuiltinTanh = 28, + kTfLiteBuiltinConcatEmbeddings = 29, + kTfLiteBuiltinSkipGram = 30, + kTfLiteBuiltinCall = 31, + kTfLiteBuiltinCustom = 32, + kTfLiteBuiltinEmbeddingLookupSparse = 33, + kTfLiteBuiltinPad = 34, + kTfLiteBuiltinUnidirectionalSequenceRnn = 35, + kTfLiteBuiltinGather = 36, + kTfLiteBuiltinBatchToSpaceNd = 37, + kTfLiteBuiltinSpaceToBatchNd = 38, + kTfLiteBuiltinTranspose = 39, + kTfLiteBuiltinMean = 40, + kTfLiteBuiltinSub = 41, + kTfLiteBuiltinDiv = 42, + kTfLiteBuiltinSqueeze = 43, + kTfLiteBuiltinUnidirectionalSequenceLstm = 44, + kTfLiteBuiltinStridedSlice = 45, + kTfLiteBuiltinBidirectionalSequenceRnn = 46, + kTfLiteBuiltinExp = 47, + kTfLiteBuiltinTopkV2 = 48, + kTfLiteBuiltinSplit = 49, + kTfLiteBuiltinLogSoftmax = 50, + kTfLiteBuiltinDelegate = 51, + kTfLiteBuiltinBidirectionalSequenceLstm = 52, +} TfLiteBuiltinOperator; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +} diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index b0c4d3431f9a67bc87d51ada91ed73f1661023a2..ed7f4515fa4437d61a37be93616c28a046295c5a 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -258,7 +258,7 @@ typedef struct TfLiteContext { TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, TfLiteIntArray** execution_plan); - // An tensor of tensors in the interpreter context (of length `tensors_size`) + // An array of tensors in the interpreter context (of length `tensors_size`) TfLiteTensor* tensors; // opaque full context ptr (an opaque c++ data structure) @@ -283,7 +283,8 @@ typedef struct TfLiteContext { TfLiteNode** node, TfLiteRegistration** registration); - // Replace ops with delegate. + // Replace ops with one or more stub delegate operations. This function + // does not take ownership of `nodes_to_replace`. TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( struct TfLiteContext*, TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace); diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h index f0d81cf7a4710d94bf62b81ff541c26c3aaab44a..2a64c1de725b601e9b6e9325d9faacb37df0e626 100644 --- a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -64,8 +64,11 @@ void resize(T* out, uint8_t* in, int image_height, int image_width, ops::builtin::BuiltinOpResolver resolver; TfLiteRegistration* resize_op = resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR); - interpreter->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, nullptr, - resize_op, nullptr); + auto* params = reinterpret_cast( + malloc(sizeof(TfLiteResizeBilinearParams))); + params->align_corners = false; + interpreter->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params, resize_op, + nullptr); interpreter->AllocateTensors(); diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 8e5e694a5cbe7f908572114db33c8257db6151f0..b1bbb7c67013acfb575cc1e9f9390ba191cbd08e 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -1,4 +1,4 @@ -# TensorFlow Compatibility Guide +# TensorFlow Lite & TensorFlow Compatibility Guide TensorFlow Lite supports a number of TensorFlow operations used in common inference models. As they are processed by the TensorFlow Lite Optimizing diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 6dea4e59163dea502c5511198e1d058b280fec97..0f5e17f0de0d828771e1fdbeac0e172f2ed9159c 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -25,13 +25,7 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/gemm_support.h" #include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" - -namespace { - -// std::vector preallocation tuning. -constexpr const int kSlotsToReserve = 128; - -} // namespace +#include "tensorflow/contrib/lite/schema/schema_generated.h" namespace tflite { @@ -84,8 +78,8 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) context_.GetExecutionPlan = nullptr; // Reserve some space for the tensors to avoid excessive resizing. - tensors_.reserve(kSlotsToReserve); - nodes_and_registration_.reserve(kSlotsToReserve); + tensors_.reserve(kTensorsReservedCapacity); + nodes_and_registration_.reserve(kTensorsReservedCapacity); next_execution_plan_index_to_prepare_ = 0; UseNNAPI(false); } @@ -115,6 +109,9 @@ TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) { + // Annotate the registration as DELEGATE op. + registration.builtin_code = BuiltinOperator_DELEGATE; + // Analyze the graph to find all independent subgraphs that are either // fully not-this-delegate or this-delegate computation. InterpreterInfo info(this); @@ -166,7 +163,7 @@ TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) { static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]), "TfLiteIntArray and execution_plan do not contain same type."); memcpy(plan_cache_->data, execution_plan_.data(), - sizeof(plan_cache_->data[0])); + sizeof(plan_cache_->data[0]) * execution_plan_.size()); return kTfLiteOk; } @@ -298,7 +295,20 @@ TfLiteStatus Interpreter::AddNodeWithParameters( OpInit(*registration, reinterpret_cast(builtin_data_deleter.get()), 0); } + node.builtin_data = builtin_data_deleter.release(); + // TODO(ycling): Filling `custom_initial_data` and `custom_initial_data_size` + // properly for nodes generated by ReplaceSubgraphsWithDelegateKernels. + if (registration->builtin_code == BuiltinOperator_CUSTOM) { + // When it's a CUSTOM op, the `custom_options` field in the Flatbuffer + // `Operator` table is passed in. + node.custom_initial_data = init_data; + node.custom_initial_data_size = init_data_size; + } else { + node.custom_initial_data = nullptr; + node.custom_initial_data_size = 0; + } + node_and_reg.second = *registration; execution_plan_.push_back(new_node_index); return kTfLiteOk; @@ -336,6 +346,7 @@ TfLiteStatus Interpreter::PrepareOpsStartingAt( TfLiteNode& node = nodes_and_registration_[node_index].first; const TfLiteRegistration& registration = nodes_and_registration_[node_index].second; + EnsureTensorsVectorCapacity(); if (OpPrepare(registration, &node) == kTfLiteError) { return kTfLiteError; } @@ -413,6 +424,7 @@ TfLiteStatus Interpreter::Invoke() { TfLiteNode& node = nodes_and_registration_[node_index].first; const TfLiteRegistration& registration = nodes_and_registration_[node_index].second; + EnsureTensorsVectorCapacity(); if (OpInvoke(registration, &node) == kTfLiteError) { status = kTfLiteError; } diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index bab56a9d72f8992a9d8af23f92133c7c918fd46d..04c19644a026bff0f3693f7b05832393bafd0324 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/memory_planner.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" namespace tflite { @@ -258,6 +259,20 @@ class Interpreter { // contain new nodes that replace 1 more nodes. TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); + // WARNING: This is a deprecated interface and will be removed as soon as + // possible. Please do not use it. + // TODO(impjdi): Remove this interface after resolving dependencies. + void set_model(const Model* model) { model_ = const_cast(model); } + Model* model() const { return model_; } + + // The default capacity of `tensors_` vector. + static constexpr int kTensorsReservedCapacity = 128; + // The capacity headroom of `tensors_` vector before calling ops' + // `prepare` and `invoke` function. In these functions, it's guaranteed + // allocating up to `kTensorsCapacityHeadroom` more tensors won't invalidate + // pointers to existing tensors. + static constexpr int kTensorsCapacityHeadroom = 16; + private: // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. @@ -370,6 +385,18 @@ class Interpreter { static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, TfLiteIntArray** execution_plan); + // Ensures that `tensors_` has at least `kTensorsCapacityHeadroom` extra + // capacity. Calling this function may invalidate existing pointers to + // tensors. After calling this function, adding `kTensorsCapacityHeadroom` + // more tensors won't invalidate the pointer to existing tensors. + void EnsureTensorsVectorCapacity() { + const int required_capacity = tensors_size() + kTensorsCapacityHeadroom; + if (required_capacity > tensors_.capacity()) { + tensors_.reserve(required_capacity); + context_.tensors = tensors_.data(); + } + } + // A pure C data structure used to communicate with the pure C plugin // interface. To avoid copying tensor metadata, this is also the definitive // structure to store tensors. @@ -425,6 +452,11 @@ class Interpreter { std::unique_ptr nnapi_delegate_; std::unique_ptr memory_planner_; + + // WARNING: This is a deprecated interface and will be removed as soon as + // possible. Please do not use it. + // TODO(impjdi): Remove this interface after resolving dependencies. + Model* model_ = nullptr; }; } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index 4b309748f7216dbd20ad9e0f3a9c4b4b72e40216..2e6727b32361ab771354a3954e5e4d8f9fa833a5 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -561,6 +561,46 @@ TEST(BasicInterpreter, TestCustomErrorReporter) { ASSERT_EQ(reporter.calls, 1); } +TEST(InterpreterTensorsCapacityTest, TestWithinHeadroom) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity), + kTfLiteOk); + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + registration.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* first_tensor = context->tensors; + + int new_tensor_index; + context->AddTensors(context, Interpreter::kTensorsCapacityHeadroom, + &new_tensor_index); + EXPECT_EQ(first_tensor, context->tensors); + return kTfLiteOk; + }; + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); +} + +TEST(InterpreterTensorsCapacityTest, TestExceedHeadroom) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(Interpreter::kTensorsReservedCapacity), + kTfLiteOk); + TfLiteRegistration registration = {nullptr, nullptr, nullptr, nullptr}; + registration.prepare = [](TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* first_tensor = context->tensors; + + int new_tensor_index; + context->AddTensors(context, Interpreter::kTensorsCapacityHeadroom + 1, + &new_tensor_index); + EXPECT_NE(first_tensor, context->tensors); + return kTfLiteOk; + }; + ASSERT_EQ(interpreter.AddNodeWithParameters({0}, {1}, nullptr, 0, nullptr, + ®istration), + kTfLiteOk); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); +} + // Test fixture that allows playing with execution plans. It creates a two // node graph that can be executed in either [0,1] order or [1,0] order. // The CopyOp records when it is invoked in the class member run_order_ @@ -773,6 +813,8 @@ class TestDelegate : public ::testing::Test { for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) { int node_index = execution_plan->data[exec_index]; + // Check that we are an identity map to start. + TFLITE_CHECK_EQ(exec_index, node_index); TfLiteNode* node; TfLiteRegistration* reg; context->GetNodeAndRegistration(context, node_index, &node, ®); diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt new file mode 100644 index 0000000000000000000000000000000000000000..572eccf90087c1c19874e40b950c1610f59cc9c2 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_imagenet_slim.txt @@ -0,0 +1,1001 @@ +dummy +tench +goldfish +great white shark +tiger shark +hammerhead +electric ray +stingray +cock +hen +ostrich +brambling +goldfinch +house finch +junco +indigo bunting +robin +bulbul +jay +magpie +chickadee +water ouzel +kite +bald eagle +vulture +great grey owl +European fire salamander +common newt +eft +spotted salamander +axolotl +bullfrog +tree frog +tailed frog +loggerhead +leatherback turtle +mud turtle +terrapin +box turtle +banded gecko +common iguana +American chameleon +whiptail +agama +frilled lizard +alligator lizard +Gila monster +green lizard +African chameleon +Komodo dragon +African crocodile +American alligator +triceratops +thunder snake +ringneck snake +hognose snake +green snake +king snake +garter snake +water snake +vine snake +night snake +boa constrictor +rock python +Indian cobra +green mamba +sea snake +horned viper +diamondback +sidewinder +trilobite +harvestman +scorpion +black and gold garden spider +barn spider +garden spider +black widow +tarantula +wolf spider +tick +centipede +black grouse +ptarmigan +ruffed grouse +prairie chicken +peacock +quail +partridge +African grey +macaw +sulphur-crested cockatoo +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser +goose +black swan +tusker +echidna +platypus +wallaby +koala +wombat +jellyfish +sea anemone +brain coral +flatworm +nematode +conch +snail +slug +sea slug +chiton +chambered nautilus +Dungeness crab +rock crab +fiddler crab +king crab +American lobster +spiny lobster +crayfish +hermit crab +isopod +white stork +black stork +spoonbill +flamingo +little blue heron +American egret +bittern +crane +limpkin +European gallinule +American coot +bustard +ruddy turnstone +red-backed sandpiper +redshank +dowitcher +oystercatcher +pelican +king penguin +albatross +grey whale +killer whale +dugong +sea lion +Chihuahua +Japanese spaniel +Maltese dog +Pekinese +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound +basset +beagle +bloodhound +bluetick +black-and-tan coonhound +Walker hound +English foxhound +redbone +borzoi +Irish wolfhound +Italian greyhound +whippet +Ibizan hound +Norwegian elkhound +otterhound +Saluki +Scottish deerhound +Weimaraner +Staffordshire bullterrier +American Staffordshire terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier +Airedale +cairn +Australian terrier +Dandie Dinmont +Boston bull +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier +Tibetan terrier +silky terrier +soft-coated wheaten terrier +West Highland white terrier +Lhasa +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla +English setter +Irish setter +Gordon setter +Brittany spaniel +clumber +English springer +Welsh springer spaniel +cocker spaniel +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog +Shetland sheepdog +collie +Border collie +Bouvier des Flandres +Rottweiler +German shepherd +Doberman +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard +Eskimo dog +malamute +Siberian husky +dalmatian +affenpinscher +basenji +pug +Leonberg +Newfoundland +Great Pyrenees +Samoyed +Pomeranian +chow +keeshond +Brabancon griffon +Pembroke +Cardigan +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf +white wolf +red wolf +coyote +dingo +dhole +African hunting dog +hyena +red fox +kit fox +Arctic fox +grey fox +tabby +tiger cat +Persian cat +Siamese cat +Egyptian cat +cougar +lynx +leopard +snow leopard +jaguar +lion +tiger +cheetah +brown bear +American black bear +ice bear +sloth bear +mongoose +meerkat +tiger beetle +ladybug +ground beetle +long-horned beetle +leaf beetle +dung beetle +rhinoceros beetle +weevil +fly +bee +ant +grasshopper +cricket +walking stick +cockroach +mantis +cicada +leafhopper +lacewing +dragonfly +damselfly +admiral +ringlet +monarch +cabbage butterfly +sulphur butterfly +lycaenid +starfish +sea urchin +sea cucumber +wood rabbit +hare +Angora +hamster +porcupine +fox squirrel +marmot +beaver +guinea pig +sorrel +zebra +hog +wild boar +warthog +hippopotamus +ox +water buffalo +bison +ram +bighorn +ibex +hartebeest +impala +gazelle +Arabian camel +llama +weasel +mink +polecat +black-footed ferret +otter +skunk +badger +armadillo +three-toed sloth +orangutan +gorilla +chimpanzee +gibbon +siamang +guenon +patas +baboon +macaque +langur +colobus +proboscis monkey +marmoset +capuchin +howler monkey +titi +spider monkey +squirrel monkey +Madagascar cat +indri +Indian elephant +African elephant +lesser panda +giant panda +barracouta +eel +coho +rock beauty +anemone fish +sturgeon +gar +lionfish +puffer +abacus +abaya +academic gown +accordion +acoustic guitar +aircraft carrier +airliner +airship +altar +ambulance +amphibian +analog clock +apiary +apron +ashcan +assault rifle +backpack +bakery +balance beam +balloon +ballpoint +Band Aid +banjo +bannister +barbell +barber chair +barbershop +barn +barometer +barrel +barrow +baseball +basketball +bassinet +bassoon +bathing cap +bath towel +bathtub +beach wagon +beacon +beaker +bearskin +beer bottle +beer glass +bell cote +bib +bicycle-built-for-two +bikini +binder +binoculars +birdhouse +boathouse +bobsled +bolo tie +bonnet +bookcase +bookshop +bottlecap +bow +bow tie +brass +brassiere +breakwater +breastplate +broom +bucket +buckle +bulletproof vest +bullet train +butcher shop +cab +caldron +candle +cannon +canoe +can opener +cardigan +car mirror +carousel +carpenter's kit +carton +car wheel +cash machine +cassette +cassette player +castle +catamaran +CD player +cello +cellular telephone +chain +chainlink fence +chain mail +chain saw +chest +chiffonier +chime +china cabinet +Christmas stocking +church +cinema +cleaver +cliff dwelling +cloak +clog +cocktail shaker +coffee mug +coffeepot +coil +combination lock +computer keyboard +confectionery +container ship +convertible +corkscrew +cornet +cowboy boot +cowboy hat +cradle +crane +crash helmet +crate +crib +Crock Pot +croquet ball +crutch +cuirass +dam +desk +desktop computer +dial telephone +diaper +digital clock +digital watch +dining table +dishrag +dishwasher +disk brake +dock +dogsled +dome +doormat +drilling platform +drum +drumstick +dumbbell +Dutch oven +electric fan +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa +file +fireboat +fire engine +fire screen +flagpole +flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn +frying pan +fur coat +garbage truck +gasmask +gas pump +goblet +go-kart +golf ball +golfcart +gondola +gong +gown +grand piano +greenhouse +grille +grocery store +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower +hand-held computer +handkerchief +hard disc +harmonica +harp +harvester +hatchet +holster +home theater +honeycomb +hook +hoopskirt +horizontal bar +horse cart +hourglass +iPod +iron +jack-o'-lantern +jean +jeep +jersey +jigsaw puzzle +jinrikisha +joystick +kimono +knee pad +knot +lab coat +ladle +lampshade +laptop +lawn mower +lens cap +letter opener +library +lifeboat +lighter +limousine +liner +lipstick +Loafer +lotion +loudspeaker +loupe +lumbermill +magnetic compass +mailbag +mailbox +maillot +maillot +manhole cover +maraca +marimba +mask +matchstick +maypole +maze +measuring cup +medicine chest +megalith +microphone +microwave +military uniform +milk can +minibus +miniskirt +minivan +missile +mitten +mixing bowl +mobile home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter +mountain bike +mountain tent +mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook +obelisk +oboe +ocarina +odometer +oil filter +organ +oscilloscope +overskirt +oxcart +oxygen mask +packet +paddle +paddlewheel +padlock +paintbrush +pajama +palace +panpipe +paper towel +parachute +parallel bars +park bench +parking meter +passenger car +patio +pay-phone +pedestal +pencil box +pencil sharpener +perfume +Petri dish +photocopier +pick +pickelhaube +picket fence +pickup +pier +piggy bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate +pitcher +plane +planetarium +plastic bag +plate rack +plow +plunger +Polaroid camera +pole +police van +poncho +pool table +pop bottle +pot +potter's wheel +power drill +prayer rug +printer +prison +projectile +projector +puck +punching bag +purse +quill +quilt +racer +racket +radiator +radio +radio telescope +rain barrel +recreational vehicle +reel +reflex camera +refrigerator +remote control +restaurant +revolver +rifle +rocking chair +rotisserie +rubber eraser +rugby ball +rule +running shoe +safe +safety pin +saltshaker +sandal +sarong +sax +scabbard +scale +school bus +schooner +scoreboard +screen +screw +screwdriver +seat belt +sewing machine +shield +shoe shop +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule +sliding door +slot +snorkel +snowmobile +snowplow +soap dispenser +soccer ball +sock +solar dish +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web +spindle +sports car +spotlight +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch +stove +strainer +streetcar +stretcher +studio couch +stupa +submarine +suit +sundial +sunglass +sunglasses +sunscreen +suspension bridge +swab +sweatshirt +swimming trunks +swing +switch +syringe +table lamp +tank +tape player +teapot +teddy +television +tennis ball +thatch +theater curtain +thimble +thresher +throne +tile roof +toaster +tobacco shop +toilet seat +torch +totem pole +tow truck +toyshop +tractor +trailer truck +tray +trench coat +tricycle +trimaran +tripod +triumphal arch +trolleybus +trombone +tub +turnstile +typewriter keyboard +umbrella +unicycle +upright +vacuum +vase +vault +velvet +vending machine +vestment +viaduct +violin +volleyball +waffle iron +wall clock +wallet +wardrobe +warplane +washbasin +washer +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool +worm fence +wreck +yawl +yurt +web site +comic book +crossword puzzle +street sign +traffic light +book jacket +menu +plate +guacamole +consomme +hot pot +trifle +ice cream +ice lolly +French loaf +bagel +pretzel +cheeseburger +hotdog +mashed potato +head cabbage +broccoli +cauliflower +zucchini +spaghetti squash +acorn squash +butternut squash +cucumber +artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple +banana +jackfruit +custard apple +pomegranate +hay +carbonara +chocolate sauce +dough +meat loaf +pizza +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff +coral reef +geyser +lakeside +promontory +sandbar +seashore +valley +volcano +ballplayer +groom +scuba diver +rapeseed +daisy +yellow lady's slipper +corn +acorn +hip +buckeye +coral fungus +agaric +gyromitra +stinkhorn +earthstar +hen-of-the-woods +bolete +ear +toilet tissue diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt b/tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt similarity index 100% rename from tensorflow/contrib/lite/java/demo/app/src/main/assets/labels.txt rename to tensorflow/contrib/lite/java/demo/app/src/main/assets/labels_mobilenet_quant_v1_224.txt diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java index 74737a8b883d23684220dd32bbd7a9e8ab4b2123..9b9fdffab557060f0211a0ce361b002cc7d03956 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java @@ -296,7 +296,8 @@ public class Camera2BasicFragment extends Fragment public void onActivityCreated(Bundle savedInstanceState) { super.onActivityCreated(savedInstanceState); try { - classifier = new ImageClassifier(getActivity()); + // create either a new ImageClassifierQuantizedMobileNet or an ImageClassifierFloatInception + classifier = new ImageClassifierQuantizedMobileNet(getActivity()); } catch (IOException e) { Log.e(TAG, "Failed to initialize an image classifier."); } @@ -658,8 +659,7 @@ public class Camera2BasicFragment extends Fragment showToast("Uninitialized Classifier or invalid context."); return; } - Bitmap bitmap = - textureView.getBitmap(ImageClassifier.DIM_IMG_SIZE_X, ImageClassifier.DIM_IMG_SIZE_Y); + Bitmap bitmap = textureView.getBitmap(classifier.getImageSizeX(), classifier.getImageSizeY()); String textToShow = classifier.classifyFrame(bitmap); bitmap.recycle(); showToast(textToShow); diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java index e44c5ae6b48eda187079dd3a0a1bc563276d816e..c57bb348c5b386a59327c7b1bc769717ca755269 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -20,6 +20,9 @@ import android.content.res.AssetFileDescriptor; import android.graphics.Bitmap; import android.os.SystemClock; import android.util.Log; + +import org.tensorflow.lite.Interpreter; + import java.io.BufferedReader; import java.io.FileInputStream; import java.io.IOException; @@ -34,20 +37,15 @@ import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.PriorityQueue; -import org.tensorflow.lite.Interpreter; -/** Classifies images with Tensorflow Lite. */ -public class ImageClassifier { +/** + * Classifies images with Tensorflow Lite. + */ +public abstract class ImageClassifier { /** Tag for the {@link Log}. */ private static final String TAG = "TfLiteCameraDemo"; - /** Name of the model file stored in Assets. */ - private static final String MODEL_PATH = "mobilenet_quant_v1_224.tflite"; - - /** Name of the label file stored in Assets. */ - private static final String LABEL_PATH = "labels.txt"; - /** Number of results to show in the UI. */ private static final int RESULTS_TO_SHOW = 3; @@ -56,23 +54,18 @@ public class ImageClassifier { private static final int DIM_PIXEL_SIZE = 3; - static final int DIM_IMG_SIZE_X = 224; - static final int DIM_IMG_SIZE_Y = 224; - /* Preallocated buffers for storing image data in. */ - private int[] intValues = new int[DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y]; + private int[] intValues = new int[getImageSizeX() * getImageSizeY()]; /** An instance of the driver class to run model inference with Tensorflow Lite. */ - private Interpreter tflite; + protected Interpreter tflite; /** Labels corresponding to the output of the vision model. */ private List labelList; /** A ByteBuffer to hold image data, to be feed into Tensorflow Lite as inputs. */ - private ByteBuffer imgData = null; + protected ByteBuffer imgData = null; - /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */ - private byte[][] labelProbArray = null; /** multi-stage low pass filter * */ private float[][] filterLabelProbArray = null; @@ -95,10 +88,13 @@ public class ImageClassifier { labelList = loadLabelList(activity); imgData = ByteBuffer.allocateDirect( - DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE); + DIM_BATCH_SIZE + * getImageSizeX() + * getImageSizeY() + * DIM_PIXEL_SIZE + * getNumBytesPerChannel()); imgData.order(ByteOrder.nativeOrder()); - labelProbArray = new byte[1][labelList.size()]; - filterLabelProbArray = new float[FILTER_STAGES][labelList.size()]; + filterLabelProbArray = new float[FILTER_STAGES][getNumLabels()]; Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); } @@ -111,7 +107,7 @@ public class ImageClassifier { convertBitmapToByteBuffer(bitmap); // Here's where the magic happens!!! long startTime = SystemClock.uptimeMillis(); - tflite.run(imgData, labelProbArray); + runInference(); long endTime = SystemClock.uptimeMillis(); Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime)); @@ -125,12 +121,12 @@ public class ImageClassifier { } void applyFilter() { - int numLabels = labelList.size(); + int numLabels = getNumLabels(); // Low pass filter `labelProbArray` into the first stage of the filter. for (int j = 0; j < numLabels; ++j) { filterLabelProbArray[0][j] += - FILTER_FACTOR * (labelProbArray[0][j] - filterLabelProbArray[0][j]); + FILTER_FACTOR * (getProbability(j) - filterLabelProbArray[0][j]); } // Low pass filter each stage into the next. for (int i = 1; i < FILTER_STAGES; ++i) { @@ -142,7 +138,7 @@ public class ImageClassifier { // Copy the last stage filter output back to `labelProbArray`. for (int j = 0; j < numLabels; ++j) { - labelProbArray[0][j] = (byte)filterLabelProbArray[FILTER_STAGES - 1][j]; + setProbability(j, filterLabelProbArray[FILTER_STAGES - 1][j]); } } @@ -156,7 +152,7 @@ public class ImageClassifier { private List loadLabelList(Activity activity) throws IOException { List labelList = new ArrayList(); BufferedReader reader = - new BufferedReader(new InputStreamReader(activity.getAssets().open(LABEL_PATH))); + new BufferedReader(new InputStreamReader(activity.getAssets().open(getLabelPath()))); String line; while ((line = reader.readLine()) != null) { labelList.add(line); @@ -167,7 +163,7 @@ public class ImageClassifier { /** Memory-map the model file in Assets. */ private MappedByteBuffer loadModelFile(Activity activity) throws IOException { - AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(MODEL_PATH); + AssetFileDescriptor fileDescriptor = activity.getAssets().openFd(getModelPath()); FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); FileChannel fileChannel = inputStream.getChannel(); long startOffset = fileDescriptor.getStartOffset(); @@ -185,12 +181,10 @@ public class ImageClassifier { // Convert the image to floating point. int pixel = 0; long startTime = SystemClock.uptimeMillis(); - for (int i = 0; i < DIM_IMG_SIZE_X; ++i) { - for (int j = 0; j < DIM_IMG_SIZE_Y; ++j) { + for (int i = 0; i < getImageSizeX(); ++i) { + for (int j = 0; j < getImageSizeY(); ++j) { final int val = intValues[pixel++]; - imgData.put((byte) ((val >> 16) & 0xFF)); - imgData.put((byte) ((val >> 8) & 0xFF)); - imgData.put((byte) (val & 0xFF)); + addPixelValue(val); } } long endTime = SystemClock.uptimeMillis(); @@ -199,9 +193,9 @@ public class ImageClassifier { /** Prints top-K labels, to be shown in UI as the results. */ private String printTopKLabels() { - for (int i = 0; i < labelList.size(); ++i) { + for (int i = 0; i < getNumLabels(); ++i) { sortedLabels.add( - new AbstractMap.SimpleEntry<>(labelList.get(i), (labelProbArray[0][i] & 0xff) / 255.0f)); + new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i))); if (sortedLabels.size() > RESULTS_TO_SHOW) { sortedLabels.poll(); } @@ -214,4 +208,89 @@ public class ImageClassifier { } return textToShow; } + + /** + * Get the name of the model file stored in Assets. + * + * @return + */ + protected abstract String getModelPath(); + + /** + * Get the name of the label file stored in Assets. + * + * @return + */ + protected abstract String getLabelPath(); + + /** + * Get the image size along the x axis. + * + * @return + */ + protected abstract int getImageSizeX(); + + /** + * Get the image size along the y axis. + * + * @return + */ + protected abstract int getImageSizeY(); + + /** + * Get the number of bytes that is used to store a single color channel value. + * + * @return + */ + protected abstract int getNumBytesPerChannel(); + + /** + * Add pixelValue to byteBuffer. + * + * @param pixelValue + */ + protected abstract void addPixelValue(int pixelValue); + + /** + * Read the probability value for the specified label This is either the original value as it was + * read from the net's output or the updated value after the filter was applied. + * + * @param labelIndex + * @return + */ + protected abstract float getProbability(int labelIndex); + + /** + * Set the probability value for the specified label. + * + * @param labelIndex + * @param value + */ + protected abstract void setProbability(int labelIndex, Number value); + + /** + * Get the normalized probability value for the specified label. This is the final value as it + * will be shown to the user. + * + * @return + */ + protected abstract float getNormalizedProbability(int labelIndex); + + /** + * Run inference using the prepared input in {@link #imgData}. Afterwards, the result will be + * provided by getProbability(). + * + *

This additional method is necessary, because we don't have a common base for different + * primitive data types. + */ + protected abstract void runInference(); + + /** + * Get the total number of labels. + * + * @return + */ + protected int getNumLabels() { + return labelList.size(); + } } diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java new file mode 100644 index 0000000000000000000000000000000000000000..be17b85e0cd93778fd123663595c43b730fb44f7 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierFloatInception.java @@ -0,0 +1,105 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; + +import java.io.IOException; + +/** + * This classifier works with the Inception-v3 slim model. + * It applies floating point inference rather than using a quantized model. + */ +public class ImageClassifierFloatInception extends ImageClassifier { + + /** + * The inception net requires additional normalization of the used input. + */ + private static final int IMAGE_MEAN = 128; + private static final float IMAGE_STD = 128.0f; + + /** + * An array to hold inference results, to be feed into Tensorflow Lite as outputs. + * This isn't part of the super class, because we need a primitive array here. + */ + private float[][] labelProbArray = null; + + /** + * Initializes an {@code ImageClassifier}. + * + * @param activity + */ + ImageClassifierFloatInception(Activity activity) throws IOException { + super(activity); + labelProbArray = new float[1][getNumLabels()]; + } + + @Override + protected String getModelPath() { + // you can download this file from + // https://storage.googleapis.com/download.tensorflow.org/models/tflite/inception_v3_slim_2016_android_2017_11_10.zip + return "inceptionv3_slim_2016.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_imagenet_slim.txt"; + } + + @Override + protected int getImageSizeX() { + return 299; + } + + @Override + protected int getImageSizeY() { + return 299; + } + + @Override + protected int getNumBytesPerChannel() { + // a 32bit float value requires 4 bytes + return 4; + } + + @Override + protected void addPixelValue(int pixelValue) { + imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD); + } + + @Override + protected float getProbability(int labelIndex) { + return labelProbArray[0][labelIndex]; + } + + @Override + protected void setProbability(int labelIndex, Number value) { + labelProbArray[0][labelIndex] = value.floatValue(); + } + + @Override + protected float getNormalizedProbability(int labelIndex) { + // TODO the following value isn't in [0,1] yet, but may be greater. Why? + return getProbability(labelIndex); + } + + @Override + protected void runInference() { + tflite.run(imgData, labelProbArray); + } +} diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java new file mode 100644 index 0000000000000000000000000000000000000000..156c895146940adfe71f111be6e354e02b75ea48 --- /dev/null +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifierQuantizedMobileNet.java @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.tflitecamerademo; + +import android.app.Activity; + +import java.io.IOException; + +/** + * This classifier works with the quantized MobileNet model. + */ +public class ImageClassifierQuantizedMobileNet extends ImageClassifier { + + /** + * An array to hold inference results, to be feed into Tensorflow Lite as outputs. + * This isn't part of the super class, because we need a primitive array here. + */ + private byte[][] labelProbArray = null; + + /** + * Initializes an {@code ImageClassifier}. + * + * @param activity + */ + ImageClassifierQuantizedMobileNet(Activity activity) throws IOException { + super(activity); + labelProbArray = new byte[1][getNumLabels()]; + } + + @Override + protected String getModelPath() { + // you can download this file from + // https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip + return "mobilenet_quant_v1_224.tflite"; + } + + @Override + protected String getLabelPath() { + return "labels_mobilenet_quant_v1_224.txt"; + } + + @Override + protected int getImageSizeX() { + return 224; + } + + @Override + protected int getImageSizeY() { + return 224; + } + + @Override + protected int getNumBytesPerChannel() { + // the quantized model uses a single byte only + return 1; + } + + @Override + protected void addPixelValue(int pixelValue) { + imgData.put((byte) ((pixelValue >> 16) & 0xFF)); + imgData.put((byte) ((pixelValue >> 8) & 0xFF)); + imgData.put((byte) (pixelValue & 0xFF)); + } + + @Override + protected float getProbability(int labelIndex) { + return labelProbArray[0][labelIndex]; + } + + @Override + protected void setProbability(int labelIndex, Number value) { + labelProbArray[0][labelIndex] = value.byteValue(); + } + + @Override + protected float getNormalizedProbability(int labelIndex) { + return (labelProbArray[0][labelIndex] & 0xff) / 255.0f; + } + + @Override + protected void runInference() { + tflite.run(imgData, labelProbArray); + } +} diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index a8ef0daede4f3b7eeffccf77263577002d512e2c..956bd35fe67b3a487f5eb545a827908e12127455 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -104,6 +104,7 @@ cc_library( "add.cc", "basic_rnn.cc", "batch_to_space_nd.cc", + "bidirectional_sequence_lstm.cc", "bidirectional_sequence_rnn.cc", "concatenation.cc", "conv.cc", @@ -111,6 +112,7 @@ cc_library( "div.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", + "exp.cc", "fully_connected.cc", "gather.cc", "hashtable_lookup.cc", @@ -128,10 +130,12 @@ cc_library( "skip_gram.cc", "space_to_batch_nd.cc", "space_to_depth.cc", + "split.cc", "squeeze.cc", "strided_slice.cc", "sub.cc", "svdf.cc", + "topk_v2.cc", "transpose.cc", "unidirectional_sequence_lstm.cc", "unidirectional_sequence_rnn.cc", @@ -279,6 +283,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "bidirectional_sequence_lstm_test", + size = "small", + srcs = ["bidirectional_sequence_lstm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "unidirectional_sequence_lstm_test", size = "small", @@ -327,6 +343,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "exp_test", + size = "small", + srcs = ["exp_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "mean_test", size = "small", @@ -388,6 +416,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "topk_v2_test", + size = "small", + srcs = ["topk_v2_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "resize_bilinear_test", size = "small", @@ -485,6 +526,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "log_softmax_test", + size = "small", + srcs = ["log_softmax_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "lsh_projection_test", size = "small", @@ -547,6 +601,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "split_test", + size = "small", + srcs = ["split_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "squeeze_test", size = "small", diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 3c5c77815d0f2592ab549152b4d77f45b967a660..6acded3091cb820ba641bac2498799d295d7dc7f 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -337,6 +337,21 @@ TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { } } +TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, 0); + TfLiteTensor* output = GetOutput(context, node, 0); + switch (input->type) { + case kTfLiteFloat32: + optimized_ops::LogSoftmax( + GetTensorData(input), GetTensorDims(input), + GetTensorData(output), GetTensorDims(output)); + return kTfLiteOk; + default: + context->ReportError(context, "Only float32 supported currently."); + return kTfLiteError; + } +} + } // namespace activations TfLiteRegistration* Register_RELU() { @@ -381,6 +396,13 @@ TfLiteRegistration* Register_SOFTMAX() { return &r; } +TfLiteRegistration* Register_LOG_SOFTMAX() { + static TfLiteRegistration r = {activations::Init, activations::Free, + activations::GenericPrepare, + activations::LogSoftmaxEval}; + return &r; +} + } // namespace builtin } // namespace ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index 68d49944e51b043b6b82aa1589d22f6ebed37574..302e52b96db0206f77eb4c8fcffd565b1db0cd3e 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -313,6 +313,47 @@ TEST(QuantizedActivationsOpTest, Softmax2D) { kQuantizedTolerance))); } +// This contains the same test values as the Softmax test, but reference answer +// generated via the following snippet of python: +// logits1 = tf.constant([[0, -6, 2, 4],[3, -2, 10, 1]], dtype=tf.float32) +// logits2 = tf.constant([[0,-6],[2,4],[3,-2],[10,1]], dtype=tf.float32) +// lsm1 = tf.nn.log_softmax(logits1) +// lsm2 = tf.nn.log_softmax(logits2) +// with tf.Session() as sess: +// print('lsm1', sess.run(lsm1)) +// print('lsm2', sess.run(lsm2)) + +TEST(FloatActivationsOpTest, LogSoftmax) { + FloatActivationsOpModel m(BuiltinOperator_LOG_SOFTMAX, + /*input=*/{TensorType_FLOAT32, {2, 4}}); + m.SetInput({ + 0, -6, 2, 4, // + 3, -2, 10, 1, // + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + -4.14297, -10.14297, -2.14297, -.142971, // + -7.00104, -12.00104, -.00104087, -9.00104, // + }))); + + // Same input, but a different shape. + FloatActivationsOpModel m2(BuiltinOperator_LOG_SOFTMAX, + /*input=*/{TensorType_FLOAT32, {4, 2}}); + m2.SetInput({ + 0, -6, // + 2, 4, // + 3, -2, // + 10, 1, // + }); + m2.Invoke(); + EXPECT_THAT(m2.GetOutput(), ElementsAreArray(ArrayFloatNear({ + -.00247565, -6.00247, // + -2.12692, -.126928, // + -.00671534, -5.00671, // + -.000123374, -9.00012, // + }))); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc new file mode 100644 index 0000000000000000000000000000000000000000..8d70df5e21fab110be238a6f72abe9aac8a75622 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -0,0 +1,863 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace bidirectional_sequence_lstm { + +// Input Tensors of size {max_time, n_batch, n_input} +constexpr int kInputTensor = 0; + +// Forward LSTM cell tensors. +// Input weight tensors of size: {n_cell, n_input} +constexpr int kFwInputToInputWeightsTensor = 1; // Optional +constexpr int kFwInputToForgetWeightsTensor = 2; +constexpr int kFwInputToCellWeightsTensor = 3; +constexpr int kFwInputToOutputWeightsTensor = 4; + +// Recurrent weight tensors of size {n_cell, n_output} +constexpr int kFwRecurrentToInputWeightsTensor = 5; // Optional +constexpr int kFwRecurrentToForgetWeightsTensor = 6; +constexpr int kFwRecurrentToCellWeightsTensor = 7; +constexpr int kFwRecurrentToOutputWeightsTensor = 8; + +// Peephole weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kFwCellToInputWeightsTensor = 9; // Optional +constexpr int kFwCellToForgetWeightsTensor = 10; // Optional +constexpr int kFwCellToOutputWeightsTensor = 11; // Optional + +// Gates bias tensors of size {n_cell} +constexpr int kFwInputGateBiasTensor = 12; // Optional +constexpr int kFwForgetGateBiasTensor = 13; +constexpr int kFwCellGateBiasTensor = 14; +constexpr int kFwOutputGateBiasTensor = 15; + +// Projection weight tensor of size {n_output, n_cell} +constexpr int kFwProjectionWeightsTensor = 16; // Optional +// Projection bias tensor of size {n_output} +constexpr int kFwProjectionBiasTensor = 17; // Optional + +// Backward LSTM cell tensors. +// Input weight tensors of size: {n_cell, n_input} +constexpr int kBwInputToInputWeightsTensor = 18; // Optional +constexpr int kBwInputToForgetWeightsTensor = 19; +constexpr int kBwInputToCellWeightsTensor = 20; +constexpr int kBwInputToOutputWeightsTensor = 21; + +// Recurrent weight tensors of size {n_cell, n_output} +constexpr int kBwRecurrentToInputWeightsTensor = 22; // Optional +constexpr int kBwRecurrentToForgetWeightsTensor = 23; +constexpr int kBwRecurrentToCellWeightsTensor = 24; +constexpr int kBwRecurrentToOutputWeightsTensor = 25; + +// Peephole weights tensors of size {n_cell}, representing a diagonal matrix. +constexpr int kBwCellToInputWeightsTensor = 26; // Optional +constexpr int kBwCellToForgetWeightsTensor = 27; // Optional +constexpr int kBwCellToOutputWeightsTensor = 28; // Optional + +// Gates bias tensors of size {n_cell} +constexpr int kBwInputGateBiasTensor = 29; // Optional +constexpr int kBwForgetGateBiasTensor = 30; +constexpr int kBwCellGateBiasTensor = 31; +constexpr int kBwOutputGateBiasTensor = 32; + +// Projection weight tensor of size {n_output, n_cell} +constexpr int kBwProjectionWeightsTensor = 33; // Optional +// Projection bias tensor of size {n_output} +constexpr int kBwProjectionBiasTensor = 34; // Optional + +// Output tensors. +constexpr int kFwScratchBufferTensor = 0; +constexpr int kFwOutputStateTensor = 1; +constexpr int kFwCellStateTensor = 2; +constexpr int kFwOutputTensor = 3; + +constexpr int kBwScratchBufferTensor = 4; +constexpr int kBwOutputStateTensor = 5; +constexpr int kBwCellStateTensor = 6; +constexpr int kBwOutputTensor = 7; + +// Check that input tensor dimensions matches with each other. +TfLiteStatus CheckLstmTensorDimensions( + TfLiteContext* context, TfLiteNode* node, int n_input, int n_output, + int n_cell, int input_to_input_weights_tensor, + int input_to_forget_weights_tensor, int input_to_cell_weights_tensor, + int input_to_output_weights_tensor, int recurrent_to_input_weights_tensor, + int recurrent_to_forget_weights_tensor, + int recurrent_to_cell_weights_tensor, + int recurrent_to_output_weights_tensor, int cell_to_input_weights_tensor, + int cell_to_forget_weights_tensor, int cell_to_output_weights_tensor, + int input_gate_bias_tensor, int forget_gate_bias_tensor, + int cell_gate_bias_tensor, int output_gate_bias_tensor, + int projection_weights_tensor, int projection_bias_tensor) { + auto* params = reinterpret_cast(node->builtin_data); + + // Making sure clipping parameters have valid values. + // == 0 means no clipping + // > 0 means clipping + TF_LITE_ENSURE(context, params->cell_clip >= 0); + TF_LITE_ENSURE(context, params->proj_clip >= 0); + + TfLiteTensor* input_to_input_weights = + GetOptionalInputTensor(context, node, input_to_input_weights_tensor); + if (input_to_input_weights) { + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); + } + + TfLiteTensor* input_to_forget_weights = + GetInput(context, node, input_to_forget_weights_tensor); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); + + TfLiteTensor* input_to_cell_weights = + GetInput(context, node, input_to_cell_weights_tensor); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); + + TfLiteTensor* recurrent_to_input_weights = + GetOptionalInputTensor(context, node, recurrent_to_input_weights_tensor); + if (recurrent_to_input_weights) { + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], + n_output); + } + + TfLiteTensor* recurrent_to_forget_weights = + GetInput(context, node, recurrent_to_forget_weights_tensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], + n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], + n_output); + + TfLiteTensor* recurrent_to_cell_weights = + GetInput(context, node, recurrent_to_cell_weights_tensor); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); + TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], + n_output); + + // We make sure the input-gate's parameters are either both present (regular + // LSTM) or not at all (CIFG-LSTM). + const bool cifg_weights_all_or_none = + ((input_to_input_weights != nullptr) && + (recurrent_to_input_weights != nullptr)) || + ((input_to_input_weights == nullptr) && + (recurrent_to_input_weights == nullptr)); + TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); + + TfLiteTensor* cell_to_input_weights = + GetOptionalInputTensor(context, node, cell_to_input_weights_tensor); + if (cell_to_input_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); + } + + TfLiteTensor* cell_to_forget_weights = + GetOptionalInputTensor(context, node, cell_to_forget_weights_tensor); + if (cell_to_forget_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); + } + + TfLiteTensor* cell_to_output_weights = + GetOptionalInputTensor(context, node, cell_to_output_weights_tensor); + if (cell_to_output_weights) { + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); + } + + // Making sure the peephole weights are there all or none. + const bool use_cifg = (input_to_input_weights == nullptr); + const bool peephole_weights_all_or_none = + ((cell_to_input_weights != nullptr || use_cifg) && + (cell_to_forget_weights != nullptr) && + (cell_to_output_weights != nullptr)) || + ((cell_to_input_weights == nullptr) && + (cell_to_forget_weights == nullptr) && + (cell_to_output_weights == nullptr)); + TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); + + // Make sure the input gate bias is present only when not a CIFG-LSTM. + TfLiteTensor* input_gate_bias = + GetOptionalInputTensor(context, node, input_gate_bias_tensor); + if (use_cifg) { + TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); + } else { + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); + } + + TfLiteTensor* forget_gate_bias = + GetInput(context, node, forget_gate_bias_tensor); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); + + TfLiteTensor* cell_bias = GetInput(context, node, cell_gate_bias_tensor); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell); + + TfLiteTensor* output_gate_bias = + GetInput(context, node, output_gate_bias_tensor); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); + + TfLiteTensor* projection_weights = + GetOptionalInputTensor(context, node, projection_weights_tensor); + if (projection_weights) { + TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); + TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); + } + + TfLiteTensor* projection_bias = + GetOptionalInputTensor(context, node, projection_bias_tensor); + if (projection_bias) { + TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); + TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); + } + + // Making sure the projection tensors are consistent: + // 1) If projection weight is not present, then projection bias should not be + // present. + // 2) If projection weight is present, then projection bias is optional. + // TODO(ghodrat): make sure this is correct. + const bool projecton_tensors_consistent = + ((projection_weights != nullptr) || (projection_bias == nullptr)); + TF_LITE_ENSURE(context, projecton_tensors_consistent == true); + + return kTfLiteOk; +} + +TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, + TfLiteNode* node, int n_input, + int n_output, int n_cell) { + CheckLstmTensorDimensions( + context, node, n_input, n_output, n_cell, kFwInputToInputWeightsTensor, + kFwInputToForgetWeightsTensor, kFwInputToCellWeightsTensor, + kFwInputToOutputWeightsTensor, kFwRecurrentToInputWeightsTensor, + kFwRecurrentToForgetWeightsTensor, kFwRecurrentToCellWeightsTensor, + kFwRecurrentToOutputWeightsTensor, kFwCellToInputWeightsTensor, + kFwCellToForgetWeightsTensor, kFwCellToOutputWeightsTensor, + kFwInputGateBiasTensor, kFwForgetGateBiasTensor, kFwCellGateBiasTensor, + kFwOutputGateBiasTensor, kFwProjectionWeightsTensor, + kFwProjectionBiasTensor); + + CheckLstmTensorDimensions( + context, node, n_input, n_output, n_cell, kBwInputToInputWeightsTensor, + kBwInputToForgetWeightsTensor, kBwInputToCellWeightsTensor, + kBwInputToOutputWeightsTensor, kBwRecurrentToInputWeightsTensor, + kBwRecurrentToForgetWeightsTensor, kBwRecurrentToCellWeightsTensor, + kBwRecurrentToOutputWeightsTensor, kBwCellToInputWeightsTensor, + kBwCellToForgetWeightsTensor, kBwCellToOutputWeightsTensor, + kBwInputGateBiasTensor, kBwForgetGateBiasTensor, kBwCellGateBiasTensor, + kBwOutputGateBiasTensor, kBwProjectionWeightsTensor, + kBwProjectionBiasTensor); + + // Check if Forward and Backward tensors match along required dimensions. + return kTfLiteOk; +} + +// Resize the output, state and scratch tensors based on the sizes of the input +// tensors. Also check that the size of the input tensors match each other. +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. + TF_LITE_ENSURE_EQ(context, node->inputs->size, 35); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 8); + + // Inferring batch size, number of outputs and sequence length and + // number of cells from the input tensors. + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TF_LITE_ENSURE(context, input->dims->size > 1); + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + + TfLiteTensor* fw_input_to_output_weights = + GetInput(context, node, kFwInputToOutputWeightsTensor); + const int n_fw_cell = fw_input_to_output_weights->dims->data[0]; + TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1], + n_input); + + TfLiteTensor* fw_recurrent_to_output_weights = + GetInput(context, node, kFwRecurrentToOutputWeightsTensor); + TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->data[0], + n_fw_cell); + const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1]; + + // Check that input tensor dimensions matches with each other. + CheckInputTensorDimensions(context, node, n_input, n_fw_output, n_fw_cell); + + // Get the pointer to output, state and scratch buffer tensors. + TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); + TfLiteTensor* fw_output_state = + GetOutput(context, node, kFwOutputStateTensor); + TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor); + // TODO(ghodrat): Modify this as soon as we have a finalized method for + // scratch buffers. + TfLiteTensor* fw_scratch_buffer = + GetOutput(context, node, kFwScratchBufferTensor); + + // Resize the output and output_state tensors. + TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3); + fw_output_size->data[0] = max_time; + fw_output_size->data[1] = n_batch; + fw_output_size->data[2] = n_fw_output; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, fw_output, fw_output_size)); + + TfLiteIntArray* fw_output_state_size = TfLiteIntArrayCreate(2); + fw_output_state_size->data[0] = n_batch; + fw_output_state_size->data[1] = n_fw_output; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output_state, + fw_output_state_size)); + + // Resize the scratch buffer tensor. + TfLiteIntArray* fw_cell_size = TfLiteIntArrayCreate(2); + fw_cell_size->data[0] = n_batch; + fw_cell_size->data[1] = n_fw_cell; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, fw_cell_state, fw_cell_size)); + + // Mark state tensors as persistent tensors. + fw_output_state->allocation_type = kTfLiteArenaRwPersistent; + fw_cell_state->allocation_type = kTfLiteArenaRwPersistent; + + TfLiteTensor* fw_input_to_input_weights = + GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor); + const bool fw_use_cifg = (fw_input_to_input_weights == nullptr); + TfLiteIntArray* fw_scratch_buffer_size = TfLiteIntArrayCreate(2); + fw_scratch_buffer_size->data[0] = n_batch; + if (fw_use_cifg) { + // Reserving space for Cell, Forget, Output gates + fw_scratch_buffer_size->data[1] = n_fw_cell * 3; + } else { + // Reserving space for Input, Cell, Forget, Output gates + fw_scratch_buffer_size->data[1] = n_fw_cell * 4; + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer, + fw_scratch_buffer_size)); + // Same for the backward cell. + TfLiteTensor* bw_input_to_output_weights = + GetInput(context, node, kBwInputToOutputWeightsTensor); + const int n_bw_cell = bw_input_to_output_weights->dims->data[0]; + TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1], + n_input); + + TfLiteTensor* bw_recurrent_to_output_weights = + GetInput(context, node, kBwRecurrentToOutputWeightsTensor); + TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0], + n_bw_cell); + const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1]; + + // Check that input tensor dimensions matches with each other. + CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell); + + // Get the pointer to output, state and scratch buffer tensors. + TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + TfLiteTensor* bw_output_state = + GetOutput(context, node, kBwOutputStateTensor); + TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor); + // TODO(ghodrat): Modify this as soon as we have a finalized method for + // scratch buffers. + TfLiteTensor* bw_scratch_buffer = + GetOutput(context, node, kBwScratchBufferTensor); + + // Resize the output and output_state tensors. + TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3); + bw_output_size->data[0] = max_time; + bw_output_size->data[1] = n_batch; + bw_output_size->data[2] = n_bw_output; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, bw_output, bw_output_size)); + + TfLiteIntArray* bw_output_state_size = TfLiteIntArrayCreate(2); + bw_output_state_size->data[0] = n_batch; + bw_output_state_size->data[1] = n_bw_output; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output_state, + bw_output_state_size)); + + // Resize the scratch buffer tensor. + TfLiteIntArray* bw_cell_size = TfLiteIntArrayCreate(2); + bw_cell_size->data[0] = n_batch; + bw_cell_size->data[1] = n_bw_cell; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, bw_cell_state, bw_cell_size)); + + // Mark state tensors as persistent tensors. + bw_output_state->allocation_type = kTfLiteArenaRwPersistent; + bw_cell_state->allocation_type = kTfLiteArenaRwPersistent; + + TfLiteTensor* bw_input_to_input_weights = + GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor); + const bool bw_use_cifg = (bw_input_to_input_weights == nullptr); + TfLiteIntArray* bw_scratch_buffer_size = TfLiteIntArrayCreate(2); + bw_scratch_buffer_size->data[0] = n_batch; + if (bw_use_cifg) { + // Reserving space for Cell, Forget, Output gates + bw_scratch_buffer_size->data[1] = n_bw_cell * 3; + } else { + // Reserving space for Input, Cell, Forget, Output gates + bw_scratch_buffer_size->data[1] = n_bw_cell * 4; + } + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_scratch_buffer, + bw_scratch_buffer_size)); + return kTfLiteOk; +} + +// Performs an LSTM batch inference step for input specified by input_ptr_batch. +// The LSTM cell is specified by the pointers to its weights (*_weights_ptr) and +// biases (*_bias_ptr), and buffers (*_scratch), along with additional +// parameters: +// - params: various LSTM params including activation, clipping, etc., +// - use_cifg: use coupled input forget gates, +// - use_peephole: whether to use peephole connection or not, +// - n_batch: size of batch, +// - n_cell: number of cells (or units), +// - n_input: the input size, +// - n_output: the output size. +// +// The pointers to the hidden state and the output are updated as a result. +// +// The pointers with the suffix "_batch" point to data aligned in batch_major +// order, and each step processes batch_size many inputs from input_ptr_batch, +// and updates batch_size many outputs and hidden states. +void LstmBatchStep( + const float* input_ptr_batch, const float* input_to_input_weights_ptr, + const float* input_to_forget_weights_ptr, + const float* input_to_cell_weights_ptr, + const float* input_to_output_weights_ptr, + const float* recurrent_to_input_weights_ptr, + const float* recurrent_to_forget_weights_ptr, + const float* recurrent_to_cell_weights_ptr, + const float* recurrent_to_output_weights_ptr, + const float* cell_to_input_weights_ptr, + const float* cell_to_forget_weights_ptr, + const float* cell_to_output_weights_ptr, const float* input_gate_bias_ptr, + const float* forget_gate_bias_ptr, const float* cell_bias_ptr, + const float* output_gate_bias_ptr, const float* projection_weights_ptr, + const float* projection_bias_ptr, const TfLiteLSTMParams* params, + bool use_cifg, bool use_peephole, int n_batch, int n_cell, int n_input, + int n_output, float* output_state_ptr, float* cell_state_ptr, + float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch, + float* output_gate_scratch, float* output_ptr_time) { + // Initialize scratch buffers with bias. + if (!use_cifg) { + tensor_utils::VectorBatchVectorAssign(input_gate_bias_ptr, n_cell, n_batch, + input_gate_scratch); + } + tensor_utils::VectorBatchVectorAssign(forget_gate_bias_ptr, n_cell, n_batch, + forget_gate_scratch); + tensor_utils::VectorBatchVectorAssign(cell_bias_ptr, n_cell, n_batch, + cell_scratch); + tensor_utils::VectorBatchVectorAssign(output_gate_bias_ptr, n_cell, n_batch, + output_gate_scratch); + + // For each batch and cell: compute input_weight * input. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + input_gate_scratch, /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + forget_gate_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch, + output_gate_scratch, /*result_stride=*/1); + + // For each batch and cell: compute recurrent_weight * output_state. + if (!use_cifg) { + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, input_gate_scratch, + /*result_stride=*/1); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, forget_gate_scratch, + /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, cell_scratch, /*result_stride=*/1); + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr, + n_batch, output_gate_scratch, + /*result_stride=*/1); + + // For each batch and cell: update input gate. + if (!use_cifg) { + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch, + input_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch, + input_gate_scratch); + } + + // For each batch and cell: update forget gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch, + forget_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch, + forget_gate_scratch); + + // For each batch and cell: update the cell. + tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr, + n_batch * n_cell, cell_state_ptr); + tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell, + params->activation, cell_scratch); + if (use_cifg) { + tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell, + forget_gate_scratch); + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr); + } else { + tensor_utils::VectorVectorCwiseProductAccumulate( + cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr); + } + if (params->cell_clip > 0.0) { + tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, + params->cell_clip, cell_state_ptr); + } + + // For each batch and cell: update the output gate. + if (use_peephole) { + tensor_utils::VectorBatchVectorCwiseProductAccumulate( + cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch, + output_gate_scratch); + } + tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell, + output_gate_scratch); + tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell, + params->activation, cell_scratch); + tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch, + n_batch * n_cell, output_gate_scratch); + + // For each batch: update the projection and output_state. + const bool use_projection_weight = (projection_weights_ptr != nullptr); + const bool use_projection_bias = (projection_bias_ptr != nullptr); + if (use_projection_weight) { + if (use_projection_bias) { + tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output, + n_batch, output_ptr_time); + } else { + tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output); + } + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch, + output_ptr_time, /*result_stride=*/1); + if (params->proj_clip > 0.0) { + tensor_utils::ClipVector(output_ptr_time, n_batch * n_output, + params->proj_clip, output_ptr_time); + } + } else { + tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output, + output_ptr_time); + } + tensor_utils::CopyVector(output_ptr_time, n_batch * n_output, + output_state_ptr); +} + +// The LSTM Op engine. +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + // Input tensor. + TfLiteTensor* input = GetInput(context, node, kInputTensor); + const int max_time = input->dims->data[0]; + const int n_batch = input->dims->data[1]; + const int n_input = input->dims->data[2]; + + // Tensors for the forward cell. + TfLiteTensor* fw_input_to_input_weights = + GetOptionalInputTensor(context, node, kFwInputToInputWeightsTensor); + TfLiteTensor* fw_input_to_forget_weights = + GetInput(context, node, kFwInputToForgetWeightsTensor); + TfLiteTensor* fw_input_to_cell_weights = + GetInput(context, node, kFwInputToCellWeightsTensor); + TfLiteTensor* fw_input_to_output_weights = + GetInput(context, node, kFwInputToOutputWeightsTensor); + + TfLiteTensor* fw_recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kFwRecurrentToInputWeightsTensor); + TfLiteTensor* fw_recurrent_to_forget_weights = + GetInput(context, node, kFwRecurrentToForgetWeightsTensor); + TfLiteTensor* fw_recurrent_to_cell_weights = + GetInput(context, node, kFwRecurrentToCellWeightsTensor); + TfLiteTensor* fw_recurrent_to_output_weights = + GetInput(context, node, kFwRecurrentToOutputWeightsTensor); + + TfLiteTensor* fw_cell_to_input_weights = + GetOptionalInputTensor(context, node, kFwCellToInputWeightsTensor); + TfLiteTensor* fw_cell_to_forget_weights = + GetOptionalInputTensor(context, node, kFwCellToForgetWeightsTensor); + TfLiteTensor* fw_cell_to_output_weights = + GetOptionalInputTensor(context, node, kFwCellToOutputWeightsTensor); + + TfLiteTensor* fw_input_gate_bias = + GetOptionalInputTensor(context, node, kFwInputGateBiasTensor); + TfLiteTensor* fw_forget_gate_bias = + GetInput(context, node, kFwForgetGateBiasTensor); + TfLiteTensor* fw_cell_bias = GetInput(context, node, kFwCellGateBiasTensor); + TfLiteTensor* fw_output_gate_bias = + GetInput(context, node, kFwOutputGateBiasTensor); + + TfLiteTensor* fw_projection_weights = + GetOptionalInputTensor(context, node, kFwProjectionWeightsTensor); + TfLiteTensor* fw_projection_bias = + GetOptionalInputTensor(context, node, kFwProjectionBiasTensor); + + TfLiteTensor* fw_output_state = + GetOutput(context, node, kFwOutputStateTensor); + TfLiteTensor* fw_cell_state = GetOutput(context, node, kFwCellStateTensor); + TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); + + // Tensors for the backward cell. + TfLiteTensor* bw_input_to_input_weights = + GetOptionalInputTensor(context, node, kBwInputToInputWeightsTensor); + TfLiteTensor* bw_input_to_forget_weights = + GetInput(context, node, kBwInputToForgetWeightsTensor); + TfLiteTensor* bw_input_to_cell_weights = + GetInput(context, node, kBwInputToCellWeightsTensor); + TfLiteTensor* bw_input_to_output_weights = + GetInput(context, node, kBwInputToOutputWeightsTensor); + + TfLiteTensor* bw_recurrent_to_input_weights = + GetOptionalInputTensor(context, node, kBwRecurrentToInputWeightsTensor); + TfLiteTensor* bw_recurrent_to_forget_weights = + GetInput(context, node, kBwRecurrentToForgetWeightsTensor); + TfLiteTensor* bw_recurrent_to_cell_weights = + GetInput(context, node, kBwRecurrentToCellWeightsTensor); + TfLiteTensor* bw_recurrent_to_output_weights = + GetInput(context, node, kBwRecurrentToOutputWeightsTensor); + + TfLiteTensor* bw_cell_to_input_weights = + GetOptionalInputTensor(context, node, kBwCellToInputWeightsTensor); + TfLiteTensor* bw_cell_to_forget_weights = + GetOptionalInputTensor(context, node, kBwCellToForgetWeightsTensor); + TfLiteTensor* bw_cell_to_output_weights = + GetOptionalInputTensor(context, node, kBwCellToOutputWeightsTensor); + + TfLiteTensor* bw_input_gate_bias = + GetOptionalInputTensor(context, node, kBwInputGateBiasTensor); + TfLiteTensor* bw_forget_gate_bias = + GetInput(context, node, kBwForgetGateBiasTensor); + TfLiteTensor* bw_cell_bias = GetInput(context, node, kBwCellGateBiasTensor); + TfLiteTensor* bw_output_gate_bias = + GetInput(context, node, kBwOutputGateBiasTensor); + + TfLiteTensor* bw_projection_weights = + GetOptionalInputTensor(context, node, kBwProjectionWeightsTensor); + TfLiteTensor* bw_projection_bias = + GetOptionalInputTensor(context, node, kBwProjectionBiasTensor); + + TfLiteTensor* bw_output_state = + GetOutput(context, node, kBwOutputStateTensor); + TfLiteTensor* bw_cell_state = GetOutput(context, node, kBwCellStateTensor); + TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + + // n_cell and n_output will be the same size when there is no projection. + const int n_fw_cell = fw_input_to_output_weights->dims->data[0]; + const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool fw_use_cifg = (fw_input_to_input_weights == nullptr); + const bool fw_use_peephole = (fw_cell_to_output_weights != nullptr); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* fw_scratch_buffer = + GetOutput(context, node, kFwScratchBufferTensor); + float* fw_input_gate_scratch = nullptr; + float* fw_cell_scratch = nullptr; + float* fw_forget_gate_scratch = nullptr; + float* fw_output_gate_scratch = nullptr; + if (fw_use_cifg) { + fw_cell_scratch = fw_scratch_buffer->data.f; + fw_forget_gate_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch; + fw_output_gate_scratch = + fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch; + } else { + fw_input_gate_scratch = fw_scratch_buffer->data.f; + fw_cell_scratch = fw_scratch_buffer->data.f + n_fw_cell * n_batch; + fw_forget_gate_scratch = + fw_scratch_buffer->data.f + 2 * n_fw_cell * n_batch; + fw_output_gate_scratch = + fw_scratch_buffer->data.f + 3 * n_fw_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + const float* fw_input_to_input_weights_ptr = + (fw_use_cifg) ? nullptr : fw_input_to_input_weights->data.f; + const float* fw_recurrent_to_input_weights_ptr = + (fw_use_cifg) ? nullptr : fw_recurrent_to_input_weights->data.f; + const float* fw_input_gate_bias_ptr = + (fw_use_cifg) ? nullptr : fw_input_gate_bias->data.f; + const float* fw_cell_to_input_weights_ptr = + (fw_use_peephole && !fw_use_cifg) ? fw_cell_to_input_weights->data.f + : nullptr; + const float* fw_cell_to_forget_weights_ptr = + (fw_use_peephole) ? fw_cell_to_forget_weights->data.f : nullptr; + const float* fw_cell_to_output_weights_ptr = + (fw_use_peephole) ? fw_cell_to_output_weights->data.f : nullptr; + const float* fw_projection_weights_ptr = (fw_projection_weights == nullptr) + ? nullptr + : fw_projection_weights->data.f; + const float* fw_projection_bias_ptr = + (fw_projection_bias == nullptr) ? nullptr : fw_projection_bias->data.f; + + // Loop through the sequence. + for (int t = 0; t < max_time; t++) { + const float* input_ptr_batch = input->data.f + t * n_batch * n_input; + float* output_ptr_time = fw_output->data.f + t * n_batch * n_fw_output; + + LstmBatchStep( + input_ptr_batch, fw_input_to_input_weights_ptr, + fw_input_to_forget_weights->data.f, fw_input_to_cell_weights->data.f, + fw_input_to_output_weights->data.f, fw_recurrent_to_input_weights_ptr, + fw_recurrent_to_forget_weights->data.f, + fw_recurrent_to_cell_weights->data.f, + fw_recurrent_to_output_weights->data.f, fw_cell_to_input_weights_ptr, + fw_cell_to_forget_weights_ptr, fw_cell_to_output_weights_ptr, + fw_input_gate_bias_ptr, fw_forget_gate_bias->data.f, + fw_cell_bias->data.f, fw_output_gate_bias->data.f, + fw_projection_weights_ptr, fw_projection_bias_ptr, params, fw_use_cifg, + fw_use_peephole, n_batch, n_fw_cell, n_input, n_fw_output, + fw_output_state->data.f, fw_cell_state->data.f, fw_input_gate_scratch, + fw_forget_gate_scratch, fw_cell_scratch, fw_output_gate_scratch, + output_ptr_time); + } + + // n_cell and n_output will be the same size when there is no projection. + const int n_bw_cell = bw_input_to_output_weights->dims->data[0]; + const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1]; + + // Since we have already checked that weights are all there or none, we can + // check the existense of only one to the get the condition. + const bool bw_use_cifg = (bw_input_to_input_weights == nullptr); + const bool bw_use_peephole = (bw_cell_to_output_weights != nullptr); + + // Index the scratch buffers pointers to the global scratch buffer. + TfLiteTensor* bw_scratch_buffer = + GetOutput(context, node, kBwScratchBufferTensor); + float* bw_input_gate_scratch = nullptr; + float* bw_cell_scratch = nullptr; + float* bw_forget_gate_scratch = nullptr; + float* bw_output_gate_scratch = nullptr; + if (bw_use_cifg) { + bw_cell_scratch = bw_scratch_buffer->data.f; + bw_forget_gate_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch; + bw_output_gate_scratch = + bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch; + } else { + bw_input_gate_scratch = bw_scratch_buffer->data.f; + bw_cell_scratch = bw_scratch_buffer->data.f + n_bw_cell * n_batch; + bw_forget_gate_scratch = + bw_scratch_buffer->data.f + 2 * n_bw_cell * n_batch; + bw_output_gate_scratch = + bw_scratch_buffer->data.f + 3 * n_bw_cell * n_batch; + } + + // Check optional tensors, the respective pointers can be null. + const float* bw_input_to_input_weights_ptr = + (bw_use_cifg) ? nullptr : bw_input_to_input_weights->data.f; + const float* bw_recurrent_to_input_weights_ptr = + (bw_use_cifg) ? nullptr : bw_recurrent_to_input_weights->data.f; + const float* bw_input_gate_bias_ptr = + (bw_use_cifg) ? nullptr : bw_input_gate_bias->data.f; + const float* bw_cell_to_input_weights_ptr = + (bw_use_peephole && !bw_use_cifg) ? bw_cell_to_input_weights->data.f + : nullptr; + const float* bw_cell_to_forget_weights_ptr = + (bw_use_peephole) ? bw_cell_to_forget_weights->data.f : nullptr; + const float* bw_cell_to_output_weights_ptr = + (bw_use_peephole) ? bw_cell_to_output_weights->data.f : nullptr; + const float* bw_projection_weights_ptr = (bw_projection_weights == nullptr) + ? nullptr + : bw_projection_weights->data.f; + const float* bw_projection_bias_ptr = + (bw_projection_bias == nullptr) ? nullptr : bw_projection_bias->data.f; + + // Loop through the sequence backwards. + for (int t = max_time - 1; t >= 0; t--) { + const float* input_ptr_batch = input->data.f + t * n_batch * n_input; + float* output_ptr_time = bw_output->data.f + t * n_batch * n_bw_output; + + LstmBatchStep( + input_ptr_batch, bw_input_to_input_weights_ptr, + bw_input_to_forget_weights->data.f, bw_input_to_cell_weights->data.f, + bw_input_to_output_weights->data.f, bw_recurrent_to_input_weights_ptr, + bw_recurrent_to_forget_weights->data.f, + bw_recurrent_to_cell_weights->data.f, + bw_recurrent_to_output_weights->data.f, bw_cell_to_input_weights_ptr, + bw_cell_to_forget_weights_ptr, bw_cell_to_output_weights_ptr, + bw_input_gate_bias_ptr, bw_forget_gate_bias->data.f, + bw_cell_bias->data.f, bw_output_gate_bias->data.f, + bw_projection_weights_ptr, bw_projection_bias_ptr, params, bw_use_cifg, + bw_use_peephole, n_batch, n_bw_cell, n_input, n_bw_output, + bw_output_state->data.f, bw_cell_state->data.f, bw_input_gate_scratch, + bw_forget_gate_scratch, bw_cell_scratch, bw_output_gate_scratch, + output_ptr_time); + } + + // Backward step. + return kTfLiteOk; +} + +} // namespace bidirectional_sequence_lstm + +TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + bidirectional_sequence_lstm::Prepare, + bidirectional_sequence_lstm::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cca857bac0633ded01d40273d2e9e8dde488d61e --- /dev/null +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc @@ -0,0 +1,1411 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Bidirectional LSTM op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BidirectionalLSTMOpModel : public SingleOpModel { + public: + BidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + int sequence_length, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, + float proj_clip, + const std::vector>& input_shapes) + : n_batch_(n_batch), + n_input_(n_input), + n_fw_cell_(n_cell), + n_bw_cell_(n_cell), + n_fw_output_(n_output), + n_bw_output_(n_output), + sequence_length_(sequence_length) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + fw_input_to_input_weights_ = AddNullInput(); + } else { + fw_input_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + fw_input_to_forget_weights_ = AddInput(TensorType_FLOAT32); + fw_input_to_cell_weights_ = AddInput(TensorType_FLOAT32); + fw_input_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + fw_recurrent_to_input_weights_ = AddNullInput(); + } else { + fw_recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + fw_recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); + fw_recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); + fw_recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_peephole) { + if (use_cifg) { + fw_cell_to_input_weights_ = AddNullInput(); + } else { + fw_cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + fw_cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); + fw_cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + } else { + fw_cell_to_input_weights_ = AddNullInput(); + fw_cell_to_forget_weights_ = AddNullInput(); + fw_cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + fw_input_gate_bias_ = AddNullInput(); + } else { + fw_input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + fw_forget_gate_bias_ = AddInput(TensorType_FLOAT32); + fw_cell_bias_ = AddInput(TensorType_FLOAT32); + fw_output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + fw_projection_weights_ = AddInput(TensorType_FLOAT32); + if (use_projection_bias) { + fw_projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + fw_projection_bias_ = AddNullInput(); + } + } else { + fw_projection_weights_ = AddNullInput(); + fw_projection_bias_ = AddNullInput(); + } + + fw_scratch_buffer_ = AddOutput(TensorType_FLOAT32); + // TODO(ghodrat): Modify these states when we have a permanent solution for + // persistent buffer. + fw_output_state_ = AddOutput(TensorType_FLOAT32); + fw_cell_state_ = AddOutput(TensorType_FLOAT32); + fw_output_ = AddOutput(TensorType_FLOAT32); + + if (use_cifg) { + bw_input_to_input_weights_ = AddNullInput(); + } else { + bw_input_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + bw_input_to_forget_weights_ = AddInput(TensorType_FLOAT32); + bw_input_to_cell_weights_ = AddInput(TensorType_FLOAT32); + bw_input_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + bw_recurrent_to_input_weights_ = AddNullInput(); + } else { + bw_recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + bw_recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); + bw_recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); + bw_recurrent_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_peephole) { + if (use_cifg) { + bw_cell_to_input_weights_ = AddNullInput(); + } else { + bw_cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + bw_cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); + bw_cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + } else { + bw_cell_to_input_weights_ = AddNullInput(); + bw_cell_to_forget_weights_ = AddNullInput(); + bw_cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + bw_input_gate_bias_ = AddNullInput(); + } else { + bw_input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + bw_forget_gate_bias_ = AddInput(TensorType_FLOAT32); + bw_cell_bias_ = AddInput(TensorType_FLOAT32); + bw_output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + bw_projection_weights_ = AddInput(TensorType_FLOAT32); + if (use_projection_bias) { + bw_projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + bw_projection_bias_ = AddNullInput(); + } + } else { + bw_projection_weights_ = AddNullInput(); + bw_projection_bias_ = AddNullInput(); + } + + bw_scratch_buffer_ = AddOutput(TensorType_FLOAT32); + // TODO(ghodrat): Modify these states when we have a permanent solution for + // persistent buffer. + bw_output_state_ = AddOutput(TensorType_FLOAT32); + bw_cell_state_ = AddOutput(TensorType_FLOAT32); + bw_output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + // Set weights in forward and backward cells to be the same. + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(fw_input_to_input_weights_, f); + PopulateTensor(bw_input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(fw_input_to_forget_weights_, f); + PopulateTensor(bw_input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(fw_input_to_cell_weights_, f); + PopulateTensor(bw_input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(fw_input_to_output_weights_, f); + PopulateTensor(bw_input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(fw_recurrent_to_input_weights_, f); + PopulateTensor(bw_recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(fw_recurrent_to_forget_weights_, f); + PopulateTensor(bw_recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(fw_recurrent_to_cell_weights_, f); + PopulateTensor(bw_recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(fw_recurrent_to_output_weights_, f); + PopulateTensor(bw_recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(fw_cell_to_input_weights_, f); + PopulateTensor(bw_cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(fw_cell_to_forget_weights_, f); + PopulateTensor(bw_cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(fw_cell_to_output_weights_, f); + PopulateTensor(bw_cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(fw_input_gate_bias_, f); + PopulateTensor(bw_input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(fw_forget_gate_bias_, f); + PopulateTensor(bw_forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(fw_cell_bias_, f); + PopulateTensor(bw_cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(fw_output_gate_bias_, f); + PopulateTensor(bw_output_gate_bias_, f); + } + + void SetProjectionWeights(std::initializer_list f) { + PopulateTensor(fw_projection_weights_, f); + PopulateTensor(bw_projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(fw_projection_bias_, f); + PopulateTensor(bw_projection_bias_, f); + } + + void ResetFwOutputAndCellStates() { + const int zero_buffer_size = n_fw_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(fw_output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + PopulateTensor(fw_cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetBwOutputAndCellStates() { + const int zero_buffer_size = n_bw_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(bw_output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + PopulateTensor(bw_cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetFwOutput() { return ExtractVector(fw_output_); } + std::vector GetBwOutput() { return ExtractVector(bw_output_); } + + int num_inputs() { return n_input_; } + int num_fw_outputs() { return n_fw_output_; } + int num_bw_outputs() { return n_bw_output_; } + int num_fw_cells() { return n_fw_cell_; } + int num_bw_cells() { return n_bw_cell_; } + int num_batches() { return n_batch_; } + int sequence_length() { return sequence_length_; } + + private: + int input_; + int fw_input_to_input_weights_; + int fw_input_to_forget_weights_; + int fw_input_to_cell_weights_; + int fw_input_to_output_weights_; + + int fw_recurrent_to_input_weights_; + int fw_recurrent_to_forget_weights_; + int fw_recurrent_to_cell_weights_; + int fw_recurrent_to_output_weights_; + + int fw_cell_to_input_weights_; + int fw_cell_to_forget_weights_; + int fw_cell_to_output_weights_; + + int fw_input_gate_bias_; + int fw_forget_gate_bias_; + int fw_cell_bias_; + int fw_output_gate_bias_; + + int fw_projection_weights_; + int fw_projection_bias_; + + int bw_input_to_input_weights_; + int bw_input_to_forget_weights_; + int bw_input_to_cell_weights_; + int bw_input_to_output_weights_; + + int bw_recurrent_to_input_weights_; + int bw_recurrent_to_forget_weights_; + int bw_recurrent_to_cell_weights_; + int bw_recurrent_to_output_weights_; + + int bw_cell_to_input_weights_; + int bw_cell_to_forget_weights_; + int bw_cell_to_output_weights_; + + int bw_input_gate_bias_; + int bw_forget_gate_bias_; + int bw_cell_bias_; + int bw_output_gate_bias_; + + int bw_projection_weights_; + int bw_projection_bias_; + + int fw_output_; + int fw_output_state_; + int fw_cell_state_; + int fw_scratch_buffer_; + + int bw_output_; + int bw_output_state_; + int bw_cell_state_; + int bw_scratch_buffer_; + + int n_batch_; + int n_input_; + int n_fw_cell_; + int n_bw_cell_; + int n_fw_output_; + int n_bw_output_; + int sequence_length_; +}; + +TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + BidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, + /*use_peephole=*/false, /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + // Forward cell + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + // Backward cell + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToInputWeights( + {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, + -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, + -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + + lstm.SetRecurrentToCellWeights( + {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, + -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + + lstm.SetRecurrentToForgetWeights( + {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, + -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + + lstm.SetRecurrentToOutputWeights( + {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, + 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + + // Input should have n_input * sequence_length many values. + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_fw_golden_output[] = { + -0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}; + static float lstm_bw_golden_output[] = { + -0.0806187, 0.139077, 0.400476, -0.197842, + -0.0332076, 0.123838, 0.309777, -0.17621, + -0.0490733, 0.0739237, 0.067706, -0.0208124}; + + // Resetting cell_state and output_state + lstm.ResetFwOutputAndCellStates(); + lstm.ResetBwOutputAndCellStates(); + + float* batch0_start = lstm_input; + float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* fw_golden_start = lstm_fw_golden_output; + float* fw_golden_end = + fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length(); + std::vector fw_expected; + fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end); + EXPECT_THAT(lstm.GetFwOutput(), + ElementsAreArray(ArrayFloatNear(fw_expected))); + + float* bw_golden_start = lstm_bw_golden_output; + float* bw_golden_end = + bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length(); + std::vector bw_expected; + bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end); + EXPECT_THAT(lstm.GetBwOutput(), + ElementsAreArray(ArrayFloatNear(bw_expected))); + + // Check reversed inputs. + static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.}; + + // Resetting cell_state and output_state + lstm.ResetFwOutputAndCellStates(); + lstm.ResetBwOutputAndCellStates(); + + batch0_start = lstm_input_reversed; + batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + fw_expected.clear(); + for (int s = 0; s < lstm.sequence_length(); s++) { + fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs(); + fw_golden_end = fw_golden_start + lstm.num_fw_outputs(); + fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end); + } + EXPECT_THAT(lstm.GetBwOutput(), + ElementsAreArray(ArrayFloatNear(fw_expected))); + + bw_expected.clear(); + for (int s = 0; s < lstm.sequence_length(); s++) { + bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs(); + bw_golden_end = bw_golden_start + lstm.num_bw_outputs(); + bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end); + } + EXPECT_THAT(lstm.GetFwOutput(), + ElementsAreArray(ArrayFloatNear(bw_expected))); +} + +TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + BidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, + /*use_peephole=*/true, /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + {-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_fw_golden_output[] = { + -0.36444446, -0.00352185, 0.12886585, -0.05163646, + -0.42312205, -0.01218222, 0.24201041, -0.08124574, + -0.358325, -0.04621704, 0.21641694, -0.06471302}; + static float lstm_bw_golden_output[] = { + -0.401685, -0.0232794, 0.288642, -0.123074, -0.42915, -0.00871577, + 0.20912, -0.103567, -0.166398, -0.00486649, 0.0697471, -0.0537578}; + + // Resetting cell_state and output_state + lstm.ResetFwOutputAndCellStates(); + lstm.ResetBwOutputAndCellStates(); + + float* batch0_start = lstm_input; + float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* fw_golden_start = lstm_fw_golden_output; + float* fw_golden_end = + fw_golden_start + lstm.num_fw_outputs() * lstm.sequence_length(); + std::vector fw_expected; + fw_expected.insert(fw_expected.end(), fw_golden_start, fw_golden_end); + EXPECT_THAT(lstm.GetFwOutput(), + ElementsAreArray(ArrayFloatNear(fw_expected))); + + float* bw_golden_start = lstm_bw_golden_output; + float* bw_golden_end = + bw_golden_start + lstm.num_bw_outputs() * lstm.sequence_length(); + std::vector bw_expected; + bw_expected.insert(bw_expected.end(), bw_golden_start, bw_golden_end); + EXPECT_THAT(lstm.GetBwOutput(), + ElementsAreArray(ArrayFloatNear(bw_expected))); + + // Check reversed inputs. + static float lstm_input_reversed[] = {1., 1., 3., 4., 2., 3.}; + + // Resetting cell_state and output_state + lstm.ResetFwOutputAndCellStates(); + lstm.ResetBwOutputAndCellStates(); + + batch0_start = lstm_input_reversed; + batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + fw_expected.clear(); + for (int s = 0; s < lstm.sequence_length(); s++) { + fw_golden_start = lstm_fw_golden_output + s * lstm.num_fw_outputs(); + fw_golden_end = fw_golden_start + lstm.num_fw_outputs(); + fw_expected.insert(fw_expected.begin(), fw_golden_start, fw_golden_end); + } + EXPECT_THAT(lstm.GetBwOutput(), + ElementsAreArray(ArrayFloatNear(fw_expected))); + + bw_expected.clear(); + for (int s = 0; s < lstm.sequence_length(); s++) { + bw_golden_start = lstm_bw_golden_output + s * lstm.num_bw_outputs(); + bw_golden_end = bw_golden_start + lstm.num_bw_outputs(); + bw_expected.insert(bw_expected.begin(), bw_golden_start, bw_golden_end); + } + EXPECT_THAT(lstm.GetFwOutput(), + ElementsAreArray(ArrayFloatNear(bw_expected))); +} + +TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + const int sequence_length = 4; + + BidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, + /*use_peephole=*/true, /*use_projection_weights=*/true, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights( + {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); + + lstm.SetInputToForgetWeights( + {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, + -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, + -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, + 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, + -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, + -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, + 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, + 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, + 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, + -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, + -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, + -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, + 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, + 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, + -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, + 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); + + lstm.SetInputToCellWeights( + {-0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, -0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); + + lstm.SetInputToOutputWeights( + {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); + + lstm.SetInputGateBias( + {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, + -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, + -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, + 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); + + lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}); + + lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}); + + lstm.SetOutputGateBias( + {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, + 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, + 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, + -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); + + lstm.SetRecurrentToInputWeights( + {-0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + -0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}); + + lstm.SetRecurrentToForgetWeights( + {-0.057784554, -0.026057621, -0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + -0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, -0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}); + + lstm.SetRecurrentToCellWeights( + {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, -0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, -0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); + + lstm.SetRecurrentToOutputWeights({ + 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, + -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, + -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, + -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, + -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, + -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, + 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, + 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, + -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, + -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, + 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, + -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, + 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, + 0.0046702605, -0.036598757, -0.08811812, 0.10522024, -0.032441203, + 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, + 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, + 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, + -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, + 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, + 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, + -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, + -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, + -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, + -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, + -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, + 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, + -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, + 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, + -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, + -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, + 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, + 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, + -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, + 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, + -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, + -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, + -0.0190714, -0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, + -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, + 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, + -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, + -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, + -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, + 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, + -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, + 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, + 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, + 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, + 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, + 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }); + + lstm.SetCellToInputWeights( + {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); + + lstm.SetCellToForgetWeights( + {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); + + lstm.SetCellToOutputWeights( + {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 0.1457655, -0.19724406, 0.15548733}); + + lstm.SetProjectionWeights( + {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, + 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, + -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, + -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, + 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, + 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, + 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, + -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, + -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, + 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, + 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, + 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, + 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, + 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, + -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, + 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, + -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, + 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, + -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, + -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, + -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, + 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, + -0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, + -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, + 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, + -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, + 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, + 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, + 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, + 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, + -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, + 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, + -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, + -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, + 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, + 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, + -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, + -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, + 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, + -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, + 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, + -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, + -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, + 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, + -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, + -0.26151478, 0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, + 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, + 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, + 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); + + static float lstm_input[][20] = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, + 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, + 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, + 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, + 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; + + static float lstm_fw_golden_output[][64] = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, -0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + + static float lstm_combined_golden_output[][64] = { + { + -0.022014, 0.073544, -0.002235, 0.040068, -0.037136, -0.052788, + 0.075325, -0.029378, 0.024298, -0.07733 , -0.030674, -0.060229, + 0.040599, 0.011608, 0.042005, 0.045977, -0.039225, 0.076294, + 0.000735, 0.032852, -0.069869, -0.053312, 0.073527, -0.028136, + 0.021585, -0.102679, -0.004327, -0.043304, 0.072861, 0.027077, + 0.034558, 0.068292, -0.036292, 0.069832, -0.003032, 0.053829, + -0.043821, -0.072713, 0.085029, -0.040374, 0.020014, -0.104521, + -0.034504, -0.059759, 0.062569, 0.025652, 0.049306, 0.061189, + -0.025146, 0.079643, -0.005188, 0.033080, -0.048079, -0.048082, + 0.069369, -0.028900, 0.024572, -0.077547, -0.022517, -0.054477, + 0.038857, 0.013336, 0.043234, 0.044788}, + { + -0.039186, 0.070792, -0.005913, 0.02642, -0.068274, -0.05022, + 0.061444, -0.031241, 0.014996, -0.094544, -0.004146, -0.03464, + 0.058981, 0.026097, 0.039781, 0.058408, -0.031887, 0.069252, + 0.00576, 0.054062, -0.042801, -0.059974, 0.085272, -0.034453, + 0.026097, -0.0959, -0.031164, -0.058699, 0.06839, 0.020512, + 0.044727, 0.063609, -0.039863, 0.084819, -0.003909, 0.028666, + -0.075677, -0.045125, 0.070379, -0.033895, 0.022111, -0.097184, + -0.004921, -0.040851, 0.062316, 0.017435, 0.041437, 0.064568, + -0.039656, 0.060726, -0.003402, 0.036854, -0.056503, -0.058554, + 0.068588, -0.034879, 0.01352, -0.09962, -0.01434, -0.039505, + 0.065133, 0.024321, 0.038473, 0.062438 + }}; + + // Resetting cell_state and output_state + lstm.ResetFwOutputAndCellStates(); + lstm.ResetBwOutputAndCellStates(); + + for (int i = 0; i < lstm.sequence_length(); i++) { + float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end); + + float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); + float* batch1_end = batch1_start + lstm.num_inputs(); + lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end); + } + + lstm.Invoke(); + + std::vector expected; + for (int i = 0; i < lstm.sequence_length(); i++) { + float* golden_start_batch0 = + lstm_fw_golden_output[0] + i * lstm.num_fw_outputs(); + float* golden_end_batch0 = golden_start_batch0 + lstm.num_fw_outputs(); + float* golden_start_batch1 = + lstm_fw_golden_output[1] + i * lstm.num_fw_outputs(); + float* golden_end_batch1 = golden_start_batch1 + lstm.num_fw_outputs(); + expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); + expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); + } + EXPECT_THAT(lstm.GetFwOutput(), ElementsAreArray(ArrayFloatNear(expected))); + + // Check if the sum of forward backward matches the golden. + expected.clear(); + for (int i = 0; i < lstm.sequence_length(); i++) { + float* golden_start_batch0 = + lstm_combined_golden_output[0] + i * lstm.num_fw_outputs(); + float* golden_end_batch0 = golden_start_batch0 + lstm.num_fw_outputs(); + float* golden_start_batch1 = + lstm_combined_golden_output[1] + i * lstm.num_fw_outputs(); + float* golden_end_batch1 = golden_start_batch1 + lstm.num_fw_outputs(); + expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); + expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); + } + + std::vector combined; + for (int i = 0; i < lstm.GetFwOutput().size(); ++i) { + combined.push_back(lstm.GetFwOutput()[i] + lstm.GetBwOutput()[i]); + } + EXPECT_THAT(combined, ElementsAreArray(ArrayFloatNear(expected))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index 7ff907531805887afea407684fdbaa65e98d619a..a619ada86af64c299f8e518a7493db20f1011a50 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -96,38 +96,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } -template -class VectorOfInputs { - public: - VectorOfInputs(const TfLiteContext& context, const TfLiteIntArray& inputs) { - int num_inputs = inputs.size; - - all_data_.reserve(num_inputs); - all_dims_.reserve(num_inputs); - all_dims_ptr_.reserve(num_inputs); - - for (int i = 0; i < num_inputs; ++i) { - TfLiteTensor* input = &context.tensors[inputs.data[i]]; - all_data_.push_back(GetTensorData(input)); - all_dims_.push_back(GetTensorDims(input)); - } - - // Taking the pointer from inside a std::vector is only OK if the vector is - // never modified, so we populate all_dims in the previous loop and then we - // are free to grab iterators here. - for (int i = 0; i < num_inputs; ++i) { - all_dims_ptr_.push_back(&all_dims_[i]); - } - } - const T* const* data() const { return all_data_.data(); } - const Dims<4>* const* dims() const { return all_dims_ptr_.data(); } - - private: - std::vector all_data_; - std::vector> all_dims_; - std::vector*> all_dims_ptr_; -}; - template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = @@ -141,7 +109,7 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { // TODO(ycling): Activation function parameter is ignored. For now we dont have // a model with a Concatenation with fused activation function. #define TF_LITE_CONCATENATION(type, scalar) \ - VectorOfInputs all_inputs(*context, *node->inputs); \ + VectorOfTensors all_inputs(*context, *node->inputs); \ type::Concatenation( \ RemapDim(NumDimensions(output), axis), all_inputs.data(), \ all_inputs.dims(), node->inputs->size, GetTensorData(output), \ diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index 66d2c04bba4a164bbcdcf4b1a097d9aac0b3aeeb..b93a416351cae34b2df8791e382a8a2cd38dcffb 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -51,11 +51,13 @@ enum KernelType { kCblasOptimized, }; +const int kTensorNotAllocated = -1; + struct OpData { // IDs are the arbitrary identifiers used by TF Lite to identify and access // memory buffers. - int im2col_id; - int hwcn_weights_id; + int im2col_id = kTensorNotAllocated; + int hwcn_weights_id = kTensorNotAllocated; TfLitePaddingValues padding; // The scaling factor from input to output (aka the 'real multiplier') can @@ -80,8 +82,6 @@ void* Init(TfLiteContext* context, const char* buffer, size_t length) { // Instead, we allocate a new object to use as scratch space for im2col, and // to carry information from Prepare() to Eval(). auto* data = new OpData; - context->AddTensors(context, 1, &data->im2col_id); - context->AddTensors(context, 1, &data->hwcn_weights_id); gemm_support::IncrementUsageCounter(context); return data; } @@ -107,10 +107,66 @@ void TransposeFloatTensor(TfLiteTensor* input, TfLiteTensor* output) { } } +// Allocate temporary tensors (`im2col`, `hwcn_weights` if necessary). +// Note: `context->AddTensors` might invalidate pointers to existing tensors. +// Therefore the logic to add tensors are isolated into this function. +static TfLiteStatus AllocateTemporaryTensorsIfRequired(TfLiteContext* context, + TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); + + TF_LITE_ENSURE(context, node->inputs->size >= 2); + TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; + + int filter_width = filter->dims->data[2]; + int filter_height = filter->dims->data[1]; + + // We don't always need to allocate im2col. It is only used in some versions + // of the optimized Conv. This test just mimics something that happens inside + // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). + data->need_im2col = + (params->stride_width != 1 || params->stride_height != 1 || + filter_width != 1 || filter_height != 1); + // If we're using the optimized multithreaded EigenTensor implementation of + // convolution, it expects the filter weights to be transposed compared to + // the normal TF Lite buffer format. Typical TF Lite weights are + // [filter_count, filter_height, filter_width, input_depth], but for the float + // implementation we need them as [filter_height, filter_width, input_depth, + // filter_count]. We get to that format by transposing, and create a temporary + // buffer to store the results. + // This path is only used for float processing, so only create the buffer if + // we're running with that data type. + data->need_hwcn_weights = (input->type == kTfLiteFloat32); + + int temporaries_count = 0; + if (data->need_im2col) { + data->im2col_index = temporaries_count; + if (data->im2col_id == kTensorNotAllocated) { + context->AddTensors(context, 1, &data->im2col_id); + } + ++temporaries_count; + } + if (data->need_hwcn_weights) { + data->hwcn_weights_index = temporaries_count; + if (data->hwcn_weights_id == kTensorNotAllocated) { + context->AddTensors(context, 1, &data->hwcn_weights_id); + } + ++temporaries_count; + } + + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(temporaries_count); + + return kTfLiteOk; +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); OpData* data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_STATUS(AllocateTemporaryTensorsIfRequired(context, node)); + bool hasBias = node->inputs->size == 3; // Check number of inputs/outputs TF_LITE_ENSURE(context, hasBias || node->inputs->size == 2); @@ -118,6 +174,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; TfLiteTensor* input = &context->tensors[node->inputs->data[0]]; TfLiteTensor* filter = &context->tensors[node->inputs->data[1]]; + // Check dimensionality of input, filter TF_LITE_ENSURE_EQ(context, input->dims->size, 4); TF_LITE_ENSURE_EQ(context, filter->dims->size, 4); @@ -199,36 +256,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { if (output_status != kTfLiteOk) return output_status; - // We don't always need to allocate im2col. It is only used in some versions - // of the optimized Conv. This test just mimics something that happens inside - // optimized_ops.h, in order to avoid a DCHECK(!im2col_data). - data->need_im2col = - (params->stride_width != 1 || params->stride_height != 1 || - filter_width != 1 || filter_height != 1); - // If we're using the optimized multithreaded EigenTensor implementation of - // convolution, it expects the filter weights to be transposed compared to - // the normal TF Lite buffer format. Typical TF Lite weights are - // [filter_count, filter_height, filter_width, input_depth], but for the float - // implementation we need them as [filter_height, filter_width, input_depth, - // filter_count]. We get to that format by transposing, and create a temporary - // buffer to store the results. - // This path is only used for float processing, so only create the buffer if - // we're running with that data type. - data->need_hwcn_weights = (data_type == kTfLiteFloat32); - - int temporaries_count = 0; - if (data->need_im2col) { - data->im2col_index = temporaries_count; - ++temporaries_count; - } - if (data->need_hwcn_weights) { - data->hwcn_weights_index = temporaries_count; - ++temporaries_count; - } - - TfLiteIntArrayFree(node->temporaries); - node->temporaries = TfLiteIntArrayCreate(temporaries_count); - if (data->need_im2col) { node->temporaries->data[data->im2col_index] = data->im2col_id; @@ -344,7 +371,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, reference_ops::Conv(GetTensorData(input), GetTensorDims(input), GetTensorData(filter), GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, + params->stride_width, params->stride_height, 1, 1, data->padding.width, data->padding.height, output_activation_min, output_activation_max, GetTensorData(output), GetTensorDims(output), @@ -355,7 +382,7 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, optimized_ops::Conv(GetTensorData(input), GetTensorDims(input), GetTensorData(filter), GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, + params->stride_width, params->stride_height, 1, 1, data->padding.width, data->padding.height, output_activation_min, output_activation_max, GetTensorData(output), GetTensorDims(output), diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9e79b742dc2c80ce4ed9a3aa786814265dcb660 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/exp.cc @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace exp { + +// This file has reference implementation of Exp. +enum KernelType { + kReference, +}; + +struct ExpContext { + ExpContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + ExpContext op_context(context, node); + TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims); + op_context.output->type = op_context.input->type; + return context->ResizeTensor(context, op_context.output, output_dims); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + ExpContext op_context(context, node); + +#define TF_LITE_EXP(kernel_type, data_type) \ + kernel_type::Exp(GetTensorData(op_context.input), \ + NumElements(op_context.input), \ + GetTensorData(op_context.output)) + + // TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128. + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_EXP(reference_ops, float); + break; + default: + context->ReportError(context, + "Type %d is currently not supported by Exp.", + op_context.input->type); + return kTfLiteError; + } + } +#undef TF_LITE_EXP + return kTfLiteOk; +} + +} // namespace exp + +TfLiteRegistration* Register_EXP_REF() { + static TfLiteRegistration r = {nullptr, nullptr, exp::Prepare, + exp::Eval}; + return &r; +} + +// TODO(kanlig): add optimized implementation of Exp. +TfLiteRegistration* Register_EXP() { return Register_EXP_REF(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/exp_test.cc b/tensorflow/contrib/lite/kernels/exp_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..eed67369a1f30e57cd29a3975a899db41938def0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/exp_test.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ExpOpModel : public SingleOpModel { + public: + ExpOpModel(const TensorData& input, const TensorType& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_EXP, BuiltinOptions_ExpOptions, + CreateExpOptions(builder_).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int output_; +}; + +TEST(ExpOpTest, FloatTest) { + std::initializer_list data = {1.0, 0.0, -1.0, 1.0, 1.0, -1.0}; + ExpOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {2.71828, 1, 0.367879, 2.71828, 2.71828, 0.367879}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index a6ccc99a517abb2c11c03de9044e58d52cffe39e..f47fb04cbaa688b75e763ff9d3cb7df44ac3f166 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -345,6 +345,9 @@ cc_library( ":ios_arm64": [ ":neon_tensor_utils", ], + ":ios_x86_64": [ + ":neon_tensor_utils", + ], ":x86_64": [ ":neon_tensor_utils", ], diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h index e2c87df80bd927d823b150ed3799641796dfb4c7..7f6eea2d5d1cfd6f4e2a569760ecbe0d96f754c8 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -573,6 +573,46 @@ struct FloatDepthwiseConvKernel { } }; +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3); + float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4); + + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4); + // Multiply-accumulate + acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val); + acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val); + acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val); + acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val); + acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4); + acc_buffer_ptr += 20; + } + } +}; + template <> struct FloatDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -926,6 +966,7 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 20) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 2) diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h index fc5897896477711c46b06f10003acb10783d12af..dbc4f0d6fdca8279072d6ea225334722d6a89eb2 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -1205,6 +1205,55 @@ struct QuantizedDepthwiseConvKernel { } }; +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + // NEON wants to load 8 bytes at a time, but 20 is not divisible by 8. + // We load the first 16 bytes into filter_u8_{0,1} as usual. + // Then we load the 8 last bytes into filter_u8_x (x for 'extra'). + // This is redundant: the first 4 bytes of filter_u8_x are the same + // as the last 4 bytes of filter_u8_x. + uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1); + uint8x8_t filter_u8_x = vld1_u8(filter_ptr + 8 * 1 + 4); + int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + int16x8_t filter_x = vreinterpretq_s16_u16(vmovl_u8(filter_u8_x)); + filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset)); + filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset)); + filter_x = vaddq_s16(filter_x, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3); + int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4); + // Multiply-accumulate + acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input); + acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input); + acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input); + acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input); + acc_4 = vmlal_n_s16(acc_4, vget_high_s16(filter_x), input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4); + acc_buffer_ptr += 20; + } + } +}; + template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -1691,6 +1740,7 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 20) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index cd52385f417b469a24b6aa2b15f54ddad5fa9731..3866f86d38a6f200e091497cab2972ed92e25c6b 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -758,14 +758,89 @@ void Im2col(const T* input_data, const Dims<4>& input_dims, int stride, kwidth, byte_zero, output_data, output_dims); } +inline void DilatedConv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, + int dilation_width_factor, int dilation_height_factor, + int pad_width, int pad_height, + float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + // This is a copy of the reference Conv implementation. We do not currently + // have an optimized path for dilation. + (void)im2col_data; // only used in optimized code. + (void)im2col_dims; // only used in optimized code. + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int input_depth = MatchingArraySize(input_dims, 0, filter_dims, 0); + const int output_depth = MatchingArraySize(filter_dims, 3, output_dims, 0); + if (bias_data) { + TFLITE_DCHECK_EQ(ArraySize(filter_dims, 3), ArraySize(bias_dims, 0)); + } + const int input_height = ArraySize(input_dims, 2); + const int input_width = ArraySize(input_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const int filter_width = ArraySize(filter_dims, 1); + const int output_height = ArraySize(output_dims, 2); + const int output_width = ArraySize(output_dims, 1); + for (int batch = 0; batch < batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + for (int out_x = 0; out_x < output_width; ++out_x) { + for (int out_channel = 0; out_channel < output_depth; ++out_channel) { + const int in_x_origin = (out_x * stride_width) - pad_width; + const int in_y_origin = (out_y * stride_height) - pad_height; + float total = 0.f; + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + const int in_x = in_x_origin + dilation_width_factor * filter_x; + const int in_y = + in_y_origin + dilation_height_factor * filter_y; + // If the location is outside the bounds of the input image, + // use zero as a default value. + if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && + (in_y < input_height)) { + float input_value = input_data[Offset(input_dims, in_channel, + in_x, in_y, batch)]; + float filter_value = + filter_data[Offset(filter_dims, in_channel, filter_x, + filter_y, out_channel)]; + total += (input_value * filter_value); + } + } + } + } + float bias_value = 0.0f; + if (bias_data) { + bias_value = bias_data[Offset(bias_dims, out_channel, 0, 0, 0)]; + } + output_data[Offset(output_dims, out_channel, out_x, out_y, batch)] = + ActivationFunctionWithMinMax(total + bias_value, + output_activation_min, + output_activation_max); + } + } + } + } +} + inline void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, float output_activation_min, - float output_activation_max, float* output_data, - const Dims<4>& output_dims, float* im2col_data, - const Dims<4>& im2col_dims) { + int stride_width, int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims, + float* im2col_data, const Dims<4>& im2col_dims) { + if ((dilation_width_factor != 1) || (dilation_height_factor != 1)) { + return DilatedConv(input_data, input_dims, filter_data, filter_dims, + bias_data, bias_dims, stride_width, stride_height, + dilation_width_factor, dilation_height_factor, pad_width, + pad_height, output_activation_min, output_activation_max, + output_data, output_dims, im2col_data, im2col_dims); + } + (void)im2col_data; (void)im2col_dims; gemmlowp::ScopedProfilingLabel label("Conv"); @@ -805,6 +880,23 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, output_activation_max); } +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + float* output_data, const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, + stride_width, stride_height, dilation_width_factor, + dilation_height_factor, pad_width, pad_height, output_activation_min, + output_activation_max, output_data, output_dims, im2col_data, + im2col_dims); +} + // legacy, for compatibility with old checked-in code template void Conv(const float* input_data, const Dims<4>& input_dims, @@ -816,7 +908,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims, float output_activation_min, output_activation_max; GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, - stride_width, stride_height, pad_width, pad_height, + stride_width, stride_height, 1, 1, pad_width, pad_height, output_activation_min, output_activation_max, output_data, output_dims, im2col_data, im2col_dims); } @@ -830,7 +922,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims, const Dims<4>& output_dims, float* im2col_data, const Dims<4>& im2col_dims) { Conv(input_data, input_dims, filter_data, filter_dims, bias_data, - bias_dims, stride, stride, pad_width, pad_height, output_data, + bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data, output_dims, im2col_data, im2col_dims); } @@ -2081,6 +2173,198 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, output_state_map.tanh(); } +#ifdef GEMMLOWP_NEON +// In the common case of batch size 1, a fully-connected node degenerates +// to a matrix*vector product. LSTM cells contain a fully-connected node; +// when quantized, this becomes a special type of GEMV operation where +// the output is 16bit-quantized, thus needs its own special path. +inline void GEMVForLstmCell(const uint8* input_data, const Dims<4>& input_dims, + const uint8* weights_data, + const Dims<4>& weights_dims, + uint8 weights_zero_point, const int32* bias_data, + const Dims<4>& bias_dims, int32 accum_multiplier, + int accum_shift, int16* output_data, + const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("GEMVForLstmCell"); + TFLITE_DCHECK(IsPackedWithoutStrides(input_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(bias_dims)); + TFLITE_DCHECK(IsPackedWithoutStrides(output_dims)); + TFLITE_DCHECK_EQ(ArraySize(output_dims, 1) * ArraySize(output_dims, 2) * + ArraySize(output_dims, 3), + 1); + const int input_size = input_dims.strides[3]; + const int output_size = MatchingArraySize(weights_dims, 1, output_dims, 0); + // This special fast path for quantized LSTM cells does not try to support + // odd sizes that we haven't encountered in any LSTM cell, that would + // require special code (that would go untested until any LSTM cell + // exercises it). We just guard our assumptions about size evenness with + // the following assertions. + TFLITE_DCHECK(!(output_size % 4)); + TFLITE_DCHECK(!(input_size % 8)); + const int32* bias_ptr = bias_data; + int16* output_ptr = output_data; + for (int out = 0; out < output_size; out += 4) { + int32x4_t acc_0 = vdupq_n_s32(0); + int32x4_t acc_1 = vdupq_n_s32(0); + int32x4_t acc_2 = vdupq_n_s32(0); + int32x4_t acc_3 = vdupq_n_s32(0); + const int16x8_t input_offset_vec = vdupq_n_s16(-128); + const int16x8_t weights_offset_vec = vdupq_n_s16(-weights_zero_point); + int in = 0; + // Handle 16 levels of depth at a time. + for (; in <= input_size - 16; in += 16) { + const uint8x16_t input_val_u8 = vld1q_u8(input_data + in); + const uint8* weights_ptr = weights_data + in + out * input_size; + uint8x16_t weights_val_u8_0 = vld1q_u8(weights_ptr + 0 * input_size); + uint8x16_t weights_val_u8_1 = vld1q_u8(weights_ptr + 1 * input_size); + uint8x16_t weights_val_u8_2 = vld1q_u8(weights_ptr + 2 * input_size); + uint8x16_t weights_val_u8_3 = vld1q_u8(weights_ptr + 3 * input_size); + int16x8_t input_val_0, input_val_1; + const uint8x8_t low = vget_low_u8(input_val_u8); + const uint8x8_t high = vget_high_u8(input_val_u8); + input_val_0 = vreinterpretq_s16_u16(vmovl_u8(low)); + input_val_1 = vreinterpretq_s16_u16(vmovl_u8(high)); + input_val_0 = vaddq_s16(input_val_0, input_offset_vec); + input_val_1 = vaddq_s16(input_val_1, input_offset_vec); + int16x8_t weights_val_0_0, weights_val_1_0, weights_val_2_0, + weights_val_3_0; + int16x8_t weights_val_0_1, weights_val_1_1, weights_val_2_1, + weights_val_3_1; + weights_val_0_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_0))), + weights_offset_vec); + weights_val_0_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_0))), + weights_offset_vec); + weights_val_1_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_1))), + weights_offset_vec); + weights_val_1_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_1))), + weights_offset_vec); + weights_val_2_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_2))), + weights_offset_vec); + weights_val_2_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_2))), + weights_offset_vec); + weights_val_3_0 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(weights_val_u8_3))), + weights_offset_vec); + weights_val_3_1 = vaddq_s16( + vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(weights_val_u8_3))), + weights_offset_vec); + acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_0), + vget_low_s16(input_val_0)); + acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_0), + vget_low_s16(input_val_0)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_0), + vget_low_s16(input_val_0)); + acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_0), + vget_low_s16(input_val_0)); + acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_0), + vget_high_s16(input_val_0)); + acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_0), + vget_high_s16(input_val_0)); + acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_0), + vget_high_s16(input_val_0)); + acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_0), + vget_high_s16(input_val_0)); + acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0_1), + vget_low_s16(input_val_1)); + acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1_1), + vget_low_s16(input_val_1)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2_1), + vget_low_s16(input_val_1)); + acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3_1), + vget_low_s16(input_val_1)); + acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0_1), + vget_high_s16(input_val_1)); + acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1_1), + vget_high_s16(input_val_1)); + acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2_1), + vget_high_s16(input_val_1)); + acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3_1), + vget_high_s16(input_val_1)); + } + // Handle 8 levels of depth at a time. + for (; in < input_size; in += 8) { + const uint8x8_t input_val_u8 = vld1_u8(input_data + in); + const uint8* weights_ptr = weights_data + in + out * input_size; + uint8x8_t weights_val_u8_0 = vld1_u8(weights_ptr + 0 * input_size); + uint8x8_t weights_val_u8_1 = vld1_u8(weights_ptr + 1 * input_size); + uint8x8_t weights_val_u8_2 = vld1_u8(weights_ptr + 2 * input_size); + uint8x8_t weights_val_u8_3 = vld1_u8(weights_ptr + 3 * input_size); + int16x8_t input_val; + input_val = vreinterpretq_s16_u16(vmovl_u8(input_val_u8)); + input_val = vaddq_s16(input_val, input_offset_vec); + int16x8_t weights_val_0, weights_val_1, weights_val_2, weights_val_3; + weights_val_0 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_0)), + weights_offset_vec); + weights_val_1 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_1)), + weights_offset_vec); + weights_val_2 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_2)), + weights_offset_vec); + weights_val_3 = + vaddq_s16(vreinterpretq_s16_u16(vmovl_u8(weights_val_u8_3)), + weights_offset_vec); + acc_0 = vmlal_s16(acc_0, vget_low_s16(weights_val_0), + vget_low_s16(input_val)); + acc_1 = vmlal_s16(acc_1, vget_low_s16(weights_val_1), + vget_low_s16(input_val)); + acc_2 = vmlal_s16(acc_2, vget_low_s16(weights_val_2), + vget_low_s16(input_val)); + acc_3 = vmlal_s16(acc_3, vget_low_s16(weights_val_3), + vget_low_s16(input_val)); + acc_0 = vmlal_s16(acc_0, vget_high_s16(weights_val_0), + vget_high_s16(input_val)); + acc_1 = vmlal_s16(acc_1, vget_high_s16(weights_val_1), + vget_high_s16(input_val)); + acc_2 = vmlal_s16(acc_2, vget_high_s16(weights_val_2), + vget_high_s16(input_val)); + acc_3 = vmlal_s16(acc_3, vget_high_s16(weights_val_3), + vget_high_s16(input_val)); + } + // Horizontally reduce accumulators + int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1, + pairwise_reduced_acc_2, pairwise_reduced_acc_3; + pairwise_reduced_acc_0 = + vpadd_s32(vget_low_s32(acc_0), vget_high_s32(acc_0)); + pairwise_reduced_acc_1 = + vpadd_s32(vget_low_s32(acc_1), vget_high_s32(acc_1)); + pairwise_reduced_acc_2 = + vpadd_s32(vget_low_s32(acc_2), vget_high_s32(acc_2)); + pairwise_reduced_acc_3 = + vpadd_s32(vget_low_s32(acc_3), vget_high_s32(acc_3)); + const int32x2_t reduced_lo = + vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); + const int32x2_t reduced_hi = + vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); + int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); + // Add bias values. + int32x4_t bias_vec = vld1q_s32(bias_ptr); + bias_ptr += 4; + reduced = vaddq_s32(reduced, bias_vec); + int left_shift = accum_shift > 0 ? accum_shift : 0; + int right_shift = accum_shift > 0 ? 0 : -accum_shift; + reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift)); + // Multiply by the fixed-point multiplier. + reduced = vqrdmulhq_n_s32(reduced, accum_multiplier); + // Rounding-shift-right. + using gemmlowp::RoundingDivideByPOT; + reduced = RoundingDivideByPOT(reduced, right_shift); + // Narrow values down to 16 bit signed. + const int16x4_t res16 = vqmovn_s32(reduced); + vst1_s16(output_ptr, res16); + output_ptr += 4; + } +} +#endif + // Quantized LSTM cell. Currently just a copy of the reference impl in // reference_ops.h. See the big function comment there, not replicating it // here. @@ -2095,7 +2379,8 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8, const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, const Dims<4>& activ_temp_dims, int32 weights_zero_point, - int32 accum_multiplier, int accum_shift) { + int32 accum_multiplier, int accum_shift, + gemmlowp::GemmContext* gemm_context) { gemmlowp::ScopedProfilingLabel label( "LstmCell/quantized (8bit external, 16bit internal)"); // Gather dimensions information, and perform consistency checks. @@ -2144,42 +2429,131 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, // integers, and the output is 16-bit fixed-point with 3 integer bits so // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that // is explained in the function comment above. - for (int b = 0; b < fc_batches; ++b) { - for (int out_c = 0; out_c < fc_output_depth; ++out_c) { - // Internal accumulation. - // Initialize accumulator with the bias-value. - int32 accum = bias_data_int32[out_c]; - // Accumulation loop. - for (int d = 0; d < fc_accum_depth; ++d) { - int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128; - int16 weights_val = - weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point; - accum += input_val * weights_val; - } - // Down-scale the final int32 accumulator to the scale used by our - // (16-bit, using 3 integer bits) fixed-point format. The quantized - // multiplier and shift here have been pre-computed offline - // (e.g. by toco). - // Note that the implicit assumption here, that this multiplier is smaller - // than one, is equivalent to the assumption that the fully-connected - // weights min-max is enclosed within [-4, 4] (it may be narrower). - // If that eventually fails, offline tools (e.g. toco) will fail early - // and that will be easy to support as needed. For now, assuming that - // this multiplier is less than one allows us to use a simpler, more - // accurate implementation. - accum = - MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift); - // Saturate, cast to int16, and store to the temporary activations array. - accum = std::max(-32768, std::min(32767, accum)); - activ_temp_data_int16[out_c + fc_output_depth * b] = accum; - } + bool gemm_already_performed = false; +#ifdef GEMMLOWP_NEON + if (fc_batches == 1 && !(fc_output_depth % 4) && !(fc_accum_depth % 8)) { + GEMVForLstmCell(concat_temp_data_uint8, concat_temp_dims, + weights_data_uint8, weights_dims, weights_zero_point, + bias_data_int32, bias_dims, accum_multiplier, accum_shift, + activ_temp_data_int16, activ_temp_dims); + gemm_already_performed = true; + } +#endif + if (!gemm_already_performed) { + gemmlowp::MatrixMap + weights_matrix(weights_data_uint8, fc_output_depth, fc_accum_depth); + gemmlowp::MatrixMap input_matrix( + concat_temp_data_uint8, fc_accum_depth, fc_batches); + gemmlowp::MatrixMap output_matrix( + activ_temp_data_int16, fc_output_depth, fc_batches); + typedef gemmlowp::VectorMap + ColVectorMap; + ColVectorMap bias_vector(bias_data_int32, fc_output_depth); + gemmlowp::OutputStageBiasAddition bias_addition_stage; + bias_addition_stage.bias_vector = bias_vector; + gemmlowp::OutputStageScaleInt32ByFixedPointAndExponent scale_stage; + scale_stage.result_offset_after_shift = 0; + scale_stage.result_fixedpoint_multiplier = accum_multiplier; + scale_stage.result_exponent = accum_shift; + gemmlowp::OutputStageSaturatingCastToInt16 saturating_cast_int16_stage; + auto output_pipeline = std::make_tuple(bias_addition_stage, scale_stage, + saturating_cast_int16_stage); + gemmlowp::GemmWithOutputPipeline< + uint8, int16, gemmlowp::L8R8WithLhsNonzeroBitDepthParams>( + gemm_context, weights_matrix, input_matrix, &output_matrix, + -weights_zero_point, -128, output_pipeline); } // Rest of the LSTM cell: tanh and logistic math functions, and some adds // and muls, all done in 16-bit fixed-point. const int outer_size = batches * width * height; + const int16* input_gate_input_ptr = activ_temp_data_int16; + const int16* input_modulation_gate_input_ptr = + activ_temp_data_int16 + output_depth; + const int16* forget_gate_input_ptr = activ_temp_data_int16 + 2 * output_depth; + const int16* output_gate_input_ptr = activ_temp_data_int16 + 3 * output_depth; + const int16* prev_state_ptr = prev_state_data_int16; + int16* output_state_data_ptr = output_state_data_int16; + uint8* output_activ_data_ptr = output_activ_data_uint8; + for (int b = 0; b < outer_size; ++b) { - for (int c = 0; c < output_depth; ++c) { + int c = 0; +#ifdef GEMMLOWP_NEON + for (; c <= output_depth - 8; c += 8) { + // Define the fixed-point data types that we will use here. All use + // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // They only differ by the number of integral vs. fractional bits, + // determining the range of values that they can represent. + // + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8]. + // This is the range of the previous fully-connected node's output, + // which is our input here. + using F3 = gemmlowp::FixedPoint; + // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits, + // 2^StateIntegerBits]. It's used to represent the internal state, whose + // number of integer bits is currently dictated by the model. See comment + // on the StateIntegerBits template parameter above. + using FS = gemmlowp::FixedPoint; + // Implementation of input gate, using fixed-point logistic function. + F3 input_gate_input = F3::FromRaw(vld1q_s16(input_gate_input_ptr)); + input_gate_input_ptr += 8; + F0 input_gate_output = gemmlowp::logistic(input_gate_input); + // Implementation of input modulation gate, using fixed-point tanh + // function. + F3 input_modulation_gate_input = + F3::FromRaw(vld1q_s16(input_modulation_gate_input_ptr)); + input_modulation_gate_input_ptr += 8; + F0 input_modulation_gate_output = + gemmlowp::tanh(input_modulation_gate_input); + // Implementation of forget gate, using fixed-point logistic function. + F3 forget_gate_input = F3::FromRaw(vld1q_s16(forget_gate_input_ptr)); + forget_gate_input_ptr += 8; + F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); + // Implementation of output gate, using fixed-point logistic function. + F3 output_gate_input = F3::FromRaw(vld1q_s16(output_gate_input_ptr)); + output_gate_input_ptr += 8; + F0 output_gate_output = gemmlowp::logistic(output_gate_input); + // Implementation of internal multiplication nodes, still in fixed-point. + F0 input_times_input_modulation = + input_gate_output * input_modulation_gate_output; + FS prev_state = FS::FromRaw(vld1q_s16(prev_state_ptr)); + prev_state_ptr += 8; + FS prev_state_times_forget_state = forget_gate_output * prev_state; + // Implementation of internal addition node, saturating. + FS new_state = gemmlowp::SaturatingAdd( + gemmlowp::Rescale(input_times_input_modulation), + prev_state_times_forget_state); + // Implementation of last internal Tanh node, still in fixed-point. + // Since a Tanh fixed-point implementation is specialized for a given + // number or integer bits, and each specialization can have a substantial + // code size, and we already used above a Tanh on an input with 3 integer + // bits, and per the table in the above function comment there is no + // significant accuracy to be lost by clamping to [-8, +8] for a + // 3-integer-bits representation, let us just do that. This helps people + // porting this to targets where code footprint must be minimized. + F3 new_state_f3 = gemmlowp::Rescale<3>(new_state); + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3); + // Store the new internal state back to memory, as 16-bit integers. + // Note: here we store the original value with StateIntegerBits, not + // the rescaled 3-integer-bits value fed to tanh. + vst1q_s16(output_state_data_ptr, new_state.raw()); + output_state_data_ptr += 8; + // Down-scale the output activations to 8-bit integers, saturating, + // and store back to memory. + int16x8_t rescaled_output_activ = + gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); + int8x8_t int8_output_activ = vqmovn_s16(rescaled_output_activ); + uint8x8_t uint8_output_activ = + vadd_u8(vdup_n_u8(128), vreinterpret_u8_s8(int8_output_activ)); + vst1_u8(output_activ_data_ptr, uint8_output_activ); + output_activ_data_ptr += 8; + } +#endif + for (; c < output_depth; ++c) { // Define the fixed-point data types that we will use here. All use // int16 as the underlying integer type i.e. all are 16-bit fixed-point. // They only differ by the number of integral vs. fractional bits, @@ -2199,45 +2573,55 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, // on the StateIntegerBits template parameter above. using FS = gemmlowp::FixedPoint; // Implementation of input gate, using fixed-point logistic function. - F3 input_gate_input = F3::FromRaw( - activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]); + F3 input_gate_input = F3::FromRaw(*input_gate_input_ptr++); F0 input_gate_output = gemmlowp::logistic(input_gate_input); // Implementation of input modulation gate, using fixed-point tanh // function. - F3 input_modulation_gate_input = F3::FromRaw( - activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]); + F3 input_modulation_gate_input = + F3::FromRaw(*input_modulation_gate_input_ptr++); F0 input_modulation_gate_output = gemmlowp::tanh(input_modulation_gate_input); // Implementation of forget gate, using fixed-point logistic function. - F3 forget_gate_input = F3::FromRaw( - activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]); + F3 forget_gate_input = F3::FromRaw(*forget_gate_input_ptr++); F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); // Implementation of output gate, using fixed-point logistic function. - F3 output_gate_input = F3::FromRaw( - activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]); + F3 output_gate_input = F3::FromRaw(*output_gate_input_ptr++); F0 output_gate_output = gemmlowp::logistic(output_gate_input); // Implementation of internal multiplication nodes, still in fixed-point. F0 input_times_input_modulation = input_gate_output * input_modulation_gate_output; - FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]); + FS prev_state = FS::FromRaw(*prev_state_ptr++); FS prev_state_times_forget_state = forget_gate_output * prev_state; // Implementation of internal addition node, saturating. FS new_state = gemmlowp::SaturatingAdd( gemmlowp::Rescale(input_times_input_modulation), prev_state_times_forget_state); - // Implementation of last internal tanh node, still in fixed-point. - F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state); + // Implementation of last internal Tanh node, still in fixed-point. + // Since a Tanh fixed-point implementation is specialized for a given + // number or integer bits, and each specialization can have a substantial + // code size, and we already used above a Tanh on an input with 3 integer + // bits, and per the table in the above function comment there is no + // significant accuracy to be lost by clamping to [-8, +8] for a + // 3-integer-bits representation, let us just do that. This helps people + // porting this to targets where code footprint must be minimized. + F3 new_state_f3 = gemmlowp::Rescale<3>(new_state); + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3); // Store the new internal state back to memory, as 16-bit integers. - output_state_data_int16[b * output_depth + c] = new_state.raw(); + // Note: here we store the original value with StateIntegerBits, not + // the rescaled 3-integer-bits value fed to tanh. + *output_state_data_ptr++ = new_state.raw(); // Down-scale the output activations to 8-bit integers, saturating, // and store back to memory. int16 rescaled_output_activ = gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); int16 clamped_output_activ = std::max(-128, std::min(127, rescaled_output_activ)); - output_activ_data_uint8[b * output_depth + c] = - 128 + clamped_output_activ; + *output_activ_data_ptr++ = 128 + clamped_output_activ; } + input_gate_input_ptr += 3 * output_depth; + input_modulation_gate_input_ptr += 3 * output_depth; + forget_gate_input_ptr += 3 * output_depth; + output_gate_input_ptr += 3 * output_depth; } } @@ -2866,74 +3250,231 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, using FixedPointAccum = gemmlowp::FixedPoint; using FixedPoint0 = gemmlowp::FixedPoint; - gemmlowp::ScopedProfilingLabel label("Softmax"); + gemmlowp::ScopedProfilingLabel label("Softmax/8bit"); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + const int outer_size = batches * height * width; + + for (int b = 0; b < outer_size; ++b) { + const uint8* input_data_ptr = input_data + b * depth; + uint8* output_data_ptr = output_data + b * depth; + + // Determine the largest entry in the current row + uint8 max_in_row = 0; + { + int c = 0; +#ifdef USE_NEON + uint8x16_t max16_0 = vdupq_n_u8(0); + uint8x16_t max16_1 = vdupq_n_u8(0); + for (; c <= depth - 32; c += 32) { + max16_0 = vmaxq_u8(max16_0, vld1q_u8(input_data_ptr + c + 0)); + max16_1 = vmaxq_u8(max16_1, vld1q_u8(input_data_ptr + c + 16)); + } + uint8x16_t max16 = vmaxq_u8(max16_0, max16_1); + if (c <= depth - 16) { + max16 = vmaxq_u8(max16, vld1q_u8(input_data_ptr + c)); + c += 16; + } + uint8x8_t max8 = vmax_u8(vget_low_u8(max16), vget_high_u8(max16)); + if (c <= depth - 8) { + max8 = vmax_u8(max8, vld1_u8(input_data_ptr + c)); + c += 8; + } + uint8x8_t max4 = vmax_u8(max8, vext_u8(max8, max8, 4)); + uint8x8_t max2 = vmax_u8(max4, vext_u8(max4, max4, 2)); + uint8x8_t max1 = vpmax_u8(max2, max2); + max_in_row = vget_lane_u8(max1, 0); +#endif + for (; c < depth; ++c) { + max_in_row = std::max(max_in_row, input_data_ptr[c]); + } + } + +#ifdef USE_NEON + using FixedPointAccumInt32x4 = + gemmlowp::FixedPoint; + using FixedPointScaledDiffInt32x4 = + gemmlowp::FixedPoint; + using FixedPoint0Int32x4 = gemmlowp::FixedPoint; + FixedPoint0Int32x4 input_beta_multiplier_f0 = + FixedPoint0Int32x4::FromScalarRaw(input_beta_multiplier); + int16x8_t max_in_row_s16 = vdupq_n_s16(max_in_row); +#endif + + // Compute the sum of exponentials of the differences of entries in the + // current row from the largest entry in the current row. + FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + { + int c = 0; +#ifdef USE_NEON + int32x4_t diff_min_s32 = vdupq_n_s32(diff_min); + FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero(); + FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero(); + FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero(); + for (; c <= depth - 8; c += 8) { + uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c)); + int16x8_t input_diff_s16 = + vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16); + int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); + int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); + int32x4_t mask_0 = + gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32); + int32x4_t mask_1 = + gemmlowp::MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32); + FixedPointScaledDiffInt32x4 scaled_diff_0 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift)); + FixedPointScaledDiffInt32x4 scaled_diff_1 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift)); + FixedPointAccumInt32x4 exps_0 = + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_0)); + FixedPointAccumInt32x4 exps_1 = + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_1)); + FixedPointAccumInt32x4 masked_exps_0 = + SelectUsingMask(mask_0, exps_0, zeros); + FixedPointAccumInt32x4 masked_exps_1 = + SelectUsingMask(mask_1, exps_1, zeros); + sum_of_exps_0 = sum_of_exps_0 + masked_exps_0; + sum_of_exps_1 = sum_of_exps_1 + masked_exps_1; + } + int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw(); + int32x2_t sum_of_exps_reduced_2 = + vadd_s32(vget_low_s32(sum_of_exps_reduced_4), + vget_high_s32(sum_of_exps_reduced_4)); + int32x2_t sum_of_exps_reduced_1 = + vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2); + sum_of_exps = + FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0)); +#endif + for (; c < depth; ++c) { + int32 input_diff = static_cast(input_data_ptr[c]) - max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + sum_of_exps = + sum_of_exps + gemmlowp::Rescale( + exp_on_negative_values(scaled_diff_f8)); + } + } + } + + // Compute the fixed-point multiplier and shift that we need to apply to + // perform a division by the above-computed sum-of-exponentials. + int32 fixed_sum_of_exps = sum_of_exps.raw(); + int headroom_plus_one = + __builtin_clz(static_cast(fixed_sum_of_exps)); + // This is the number of bits to the left of the binary point above 1.0. + // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and + // no later adjustment will be needed. + int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; + int32 shifted_sum_minus_one = static_cast( + (static_cast(fixed_sum_of_exps) << headroom_plus_one) - + (static_cast(1) << 31)); + FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( + FixedPoint0::FromRaw(shifted_sum_minus_one)); + + // Compute the quotients of exponentials of differences of entries in the + // current row from the largest entry, over the previously-computed sum of + // exponentials. + { + int c = 0; +#ifdef USE_NEON + int16x8_t diff_min_s16 = vdupq_n_s16(diff_min); + for (; c <= depth - 8; c += 8) { + uint16x8_t input_u16 = vmovl_u8(vld1_u8(input_data_ptr + c)); + int16x8_t input_diff_s16 = + vsubq_s16(vreinterpretq_s16_u16(input_u16), max_in_row_s16); + int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); + int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); + uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16)); + FixedPointScaledDiffInt32x4 scaled_diff_0 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_0, input_beta_left_shift)); + FixedPointScaledDiffInt32x4 scaled_diff_1 = + input_beta_multiplier_f0 * + FixedPointScaledDiffInt32x4::FromRaw( + gemmlowp::ShiftLeft(input_diff_s32_1, input_beta_left_shift)); + FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0); + FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1); + int32x4_t output_s32_0 = gemmlowp::RoundingDivideByPOT( + vqrdmulhq_n_s32(exp_0.raw(), shifted_scale.raw()), + num_bits_over_unit + 31 - 8); + int32x4_t output_s32_1 = gemmlowp::RoundingDivideByPOT( + vqrdmulhq_n_s32(exp_1.raw(), shifted_scale.raw()), + num_bits_over_unit + 31 - 8); + int16x8_t output_s16 = + vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1)); + uint8x8_t output_u8 = vqmovun_s16(output_s16); + uint8x8_t masked_output = vbsl_u8(mask, output_u8, vdup_n_u8(0)); + vst1_u8(output_data_ptr + c, masked_output); + } +#endif + for (; c < depth; ++c) { + int32 input_diff = static_cast(input_data_ptr[c]) - max_in_row; + if (input_diff >= diff_min) { + const int32 input_diff_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_diff, input_beta_multiplier, input_beta_left_shift); + const FixedPointScaledDiff scaled_diff_f8 = + FixedPointScaledDiff::FromRaw(input_diff_rescaled); + + FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); + int32 unsat_output = gemmlowp::RoundingDivideByPOT( + (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); + + output_data_ptr[c] = std::max(std::min(unsat_output, 255), 0); + + } else { + output_data_ptr[c] = 0; + } + } + } + } +} + +// TODO(myenik): This is the same as the reference implementation, not actually +// optimized yet. +inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); const int height = MatchingArraySize(input_dims, 2, output_dims, 2); const int width = MatchingArraySize(input_dims, 1, output_dims, 1); const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); for (int b = 0; b < batches; ++b) { - for (int x = 0; x < width; ++x) { - for (int y = 0; y < height; ++y) { - uint8 max_in_row = 0; + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + // Find max element value which we'll use to ensure numerical stability + // taking advantage of the following equality: + // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C))) + float max = std::numeric_limits::lowest(); for (int c = 0; c < depth; ++c) { - max_in_row = - std::max(max_in_row, input_data[Offset(input_dims, c, x, y, b)]); + max = std::max(max, input_data[Offset(input_dims, c, x, y, b)]); } - FixedPointAccum sum_of_exps = FixedPointAccum::Zero(); + // Compute sum. + float sum = 0.f; for (int c = 0; c < depth; ++c) { - int32 input_diff = - static_cast(input_data[Offset(input_dims, c, x, y, b)]) - - max_in_row; - if (input_diff >= diff_min) { - const int32 input_diff_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_diff, input_beta_multiplier, input_beta_left_shift); - const FixedPointScaledDiff scaled_diff_f8 = - FixedPointScaledDiff::FromRaw(input_diff_rescaled); - sum_of_exps = - sum_of_exps + gemmlowp::Rescale( - exp_on_negative_values(scaled_diff_f8)); - } + sum += std::exp(input_data[Offset(input_dims, c, x, y, b)] - max); } - int32 fixed_sum_of_exps = sum_of_exps.raw(); - // TODO(starka): Use a NEON intrinsic like vclzq_u32 instead. - int headroom_plus_one = - __builtin_clz(static_cast(fixed_sum_of_exps)); - // This is the number of bits to the left of the binary point above 1.0. - // Consider fixed_sum_of_exps=1.25. In that case shifted_scale=0.8 and - // no later adjustment will be needed. - int num_bits_over_unit = kAccumulationIntegerBits - headroom_plus_one; - int32 shifted_sum_minus_one = static_cast( - (static_cast(fixed_sum_of_exps) << headroom_plus_one) - - (static_cast(1) << 31)); - - FixedPoint0 shifted_scale = gemmlowp::one_over_one_plus_x_for_x_in_0_1( - FixedPoint0::FromRaw(shifted_sum_minus_one)); - + // Compute result. + const float log_sum = std::log(sum); for (int c = 0; c < depth; ++c) { - int32 input_diff = - static_cast(input_data[Offset(input_dims, c, x, y, b)]) - - max_in_row; - if (input_diff >= diff_min) { - const int32 input_diff_rescaled = - MultiplyByQuantizedMultiplierGreaterThanOne( - input_diff, input_beta_multiplier, input_beta_left_shift); - const FixedPointScaledDiff scaled_diff_f8 = - FixedPointScaledDiff::FromRaw(input_diff_rescaled); - - FixedPoint0 exp_in_0 = exp_on_negative_values(scaled_diff_f8); - int32 unsat_output = gemmlowp::RoundingDivideByPOT( - (shifted_scale * exp_in_0).raw(), num_bits_over_unit + 31 - 8); - - output_data[Offset(output_dims, c, x, y, b)] = - std::max(std::min(unsat_output, 255), 0); - - } else { - output_data[Offset(output_dims, c, x, y, b)] = 0; - } + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] - max - log_sum; } } } @@ -4155,6 +4696,35 @@ void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, } } +template +void Transpose(const T* input, const Dims<4>& input_dims, T* output, + const Dims<4>& output_dims, const int* permuted_axes) { + int out_sizes[4]; + // Compute the inverse permutation array so we can do an output centered + // transpose. Also, check to make sure output_dims is matching input_dims. + for (int k = 0; k < 4; k++) { + out_sizes[k] = + MatchingArraySize(input_dims, permuted_axes[k], output_dims, k); + } + + // Naive transpose loop (iterate on output index and compute input index). + int o[4]; // loop index (on output). + int i[4]; + for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) { + i[permuted_axes[3]] = o[3]; + for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) { + i[permuted_axes[2]] = o[2]; + for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) { + i[permuted_axes[1]] = o[1]; + for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) { + i[permuted_axes[0]] = o[0]; + output[Offset(output_dims, o)] = input[Offset(input_dims, i)]; + } + } + } + } +} + } // namespace optimized_ops } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index f18543f4e4b5319e3b5848812576f8e5fb8165e8..53de21697b95039e32383a7a9d99c2e3168068c2 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -157,11 +157,11 @@ inline void NdArrayDescsForElementwiseBroadcast(const Dims& input0_dims, inline void Conv(const float* input_data, const Dims<4>& input_dims, const float* filter_data, const Dims<4>& filter_dims, const float* bias_data, const Dims<4>& bias_dims, - int stride_width, int stride_height, int pad_width, - int pad_height, float output_activation_min, - float output_activation_max, float* output_data, - const Dims<4>& output_dims, float* im2col_data, - const Dims<4>& im2col_dims) { + int stride_width, int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims, + float* im2col_data, const Dims<4>& im2col_dims) { (void)im2col_data; // only used in optimized code. (void)im2col_dims; // only used in optimized code. const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); @@ -186,8 +186,9 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, for (int filter_y = 0; filter_y < filter_height; ++filter_y) { for (int filter_x = 0; filter_x < filter_width; ++filter_x) { for (int in_channel = 0; in_channel < input_depth; ++in_channel) { - const int in_x = in_x_origin + filter_x; - const int in_y = in_y_origin + filter_y; + const int in_x = in_x_origin + dilation_width_factor * filter_x; + const int in_y = + in_y_origin + dilation_height_factor * filter_y; // If the location is outside the bounds of the input image, // use zero as a default value. if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) && @@ -216,6 +217,23 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, } } +template +void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, int stride_width, + int stride_height, int dilation_width_factor, + int dilation_height_factor, int pad_width, int pad_height, + float* output_data, const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + float output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, + stride_width, stride_height, dilation_width_factor, + dilation_height_factor, pad_width, pad_height, output_activation_min, + output_activation_max, output_data, output_dims, im2col_data, + im2col_dims); +} + // legacy, for compatibility with old checked-in code template void Conv(const float* input_data, const Dims<4>& input_dims, @@ -227,7 +245,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims, float output_activation_min, output_activation_max; GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); Conv(input_data, input_dims, filter_data, filter_dims, bias_data, bias_dims, - stride_width, stride_height, pad_width, pad_height, + stride_width, stride_height, 1, 1, pad_width, pad_height, output_activation_min, output_activation_max, output_data, output_dims, im2col_data, im2col_dims); } @@ -241,7 +259,7 @@ void Conv(const float* input_data, const Dims<4>& input_dims, const Dims<4>& output_dims, float* im2col_data, const Dims<4>& im2col_dims) { Conv(input_data, input_dims, filter_data, filter_dims, bias_data, - bias_dims, stride, stride, pad_width, pad_height, output_data, + bias_dims, stride, stride, 1, 1, pad_width, pad_height, output_data, output_dims, im2col_data, im2col_dims); } @@ -1453,7 +1471,10 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8, const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, const Dims<4>& activ_temp_dims, int32 weights_zero_point, - int32 accum_multiplier, int accum_shift) { + int32 accum_multiplier, int accum_shift, + gemmlowp::GemmContext* gemm_context) { + (void)gemm_context; // only used in optimized code. + // Gather dimensions information, and perform consistency checks. const int batches = MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, @@ -1574,9 +1595,19 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, FS new_state = gemmlowp::SaturatingAdd( gemmlowp::Rescale(input_times_input_modulation), prev_state_times_forget_state); - // Implementation of last internal tanh node, still in fixed-point. - F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state); + // Implementation of last internal Tanh node, still in fixed-point. + // Since a Tanh fixed-point implementation is specialized for a given + // number or integer bits, and each specialization can have a substantial + // code size, and we already used above a Tanh on an input with 3 integer + // bits, and per the table in the above function comment there is no + // significant accuracy to be lost by clamping to [-8, +8] for a + // 3-integer-bits representation, let us just do that. This helps people + // porting this to targets where code footprint must be minimized. + F3 new_state_f3 = gemmlowp::Rescale<3>(new_state); + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state_f3); // Store the new internal state back to memory, as 16-bit integers. + // Note: here we store the original value with StateIntegerBits, not + // the rescaled 3-integer-bits value fed to tanh. output_state_data_int16[b * output_depth + c] = new_state.raw(); // Down-scale the output activations to 8-bit integers, saturating, // and store back to memory. @@ -1590,6 +1621,33 @@ void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, } } +template +void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, + int axis, int outputs_count, Scalar* const* output_data, + const Dims<4>* const* output_dims) { + const int batches = ArraySize(*output_dims[0], 3); + const int height = ArraySize(*output_dims[0], 2); + const int width = ArraySize(*output_dims[0], 1); + const int depth = ArraySize(*output_dims[0], 0); + + const int slice_size = ArraySize(*output_dims[0], axis); + + for (int i = 0; i < outputs_count; ++i) { + int offset = i * slice_size * input_dims.strides[axis]; + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + auto out = Offset(*output_dims[i], c, x, y, b); + auto in = Offset(input_dims, c, x, y, b); + output_data[i][out] = input_data[offset + in]; + } + } + } + } + } +} + template void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int outputs_count, Scalar* const* output_data, @@ -1600,28 +1658,12 @@ void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, /* height = */ MatchingArraySize(*output_dims[i], 2, input_dims, 2); /* width = */ MatchingArraySize(*output_dims[i], 1, input_dims, 1); } - const int batches = MatchingArraySize(*output_dims[0], 3, input_dims, 3); - const int height = MatchingArraySize(*output_dims[0], 2, input_dims, 2); - const int width = MatchingArraySize(*output_dims[0], 1, input_dims, 1); // for now we dont have a model with a TensorFlowSplit // with fused activation function. TFLITE_DCHECK(Ac == FusedActivationFunctionType::kNone); - for (int b = 0; b < batches; ++b) { - for (int y = 0; y < height; ++y) { - for (int x = 0; x < width; ++x) { - int in_c = 0; - for (int i = 0; i < outputs_count; ++i) { - const int depth = ArraySize(*output_dims[i], 0); - for (int c = 0; c < depth; ++c) { - output_data[i][Offset(*output_dims[i], c, x, y, b)] = - input_data[Offset(input_dims, in_c, x, y, b)]; - in_c++; - } - } - TFLITE_DCHECK(in_c == ArraySize(input_dims, 0)); - } - } - } + + TensorFlowSplit(input_data, input_dims, /*axis=*/0, outputs_count, + output_data, output_dims); } // TODO(benoitjacob) make this a proper reference impl without Eigen! @@ -2192,6 +2234,41 @@ inline void Softmax(const uint8* input_data, const Dims<4>& input_dims, } } +inline void LogSoftmax(const float* input_data, const Dims<4>& input_dims, + float* output_data, const Dims<4>& output_dims) { + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + // Find max element value which we'll use to ensure numerical stability + // taking advantage of the following equality: + // log(exp(x[i])/sum(exp(x[i]))) == log(exp(x[i]+C)/sum(exp(x[i]+C))) + float max = std::numeric_limits::lowest(); + for (int c = 0; c < depth; ++c) { + max = std::max(max, input_data[Offset(input_dims, c, x, y, b)]); + } + + // Compute sum. + float sum = 0.f; + for (int c = 0; c < depth; ++c) { + sum += std::exp(input_data[Offset(input_dims, c, x, y, b)] - max); + } + + // Compute result. + const float log_sum = std::log(sum); + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + input_data[Offset(input_dims, c, x, y, b)] - max - log_sum; + } + } + } + } +} + inline void Logistic(const float* input_data, const Dims<4>& input_dims, float* output_data, const Dims<4>& output_dims) { const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); @@ -2762,6 +2839,14 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } } +template +inline void Exp(const T* input_data, const size_t num_elements, + T* output_data) { + for (size_t idx = 0; idx < num_elements; ++idx) { + output_data[idx] = exp(input_data[idx]); + } +} + template inline void Mean(T* input_data, const int* input_dims, const int input_num_dims, T* output_data, const int* output_dims, @@ -2814,9 +2899,11 @@ inline void Mean(T* input_data, const int* input_dims, const int input_num_dims, for (int idx = 0; idx < num_resolved_axis; ++idx) { num_elements_in_axis *= static_cast(input_dims[resolved_axis[idx]]); } - for (size_t idx = 0; idx < num_outputs; ++idx) { - output_data[idx] = static_cast(static_cast(output_data[idx]) / - num_elements_in_axis); + if (num_elements_in_axis > 0) { + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = static_cast(static_cast(output_data[idx]) / + num_elements_in_axis); + } } } diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index dfe76c2afd40c692063710a4d98464b55e40feb9..62e38e0d4c3e023d0ed2242fc9438b096b86dc59 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -81,6 +81,51 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { return GetTensorDims(dims->data, dims->size); } +// A list of tensors in a format that can be used by kernels like split and +// concatenation. +template +class VectorOfTensors { + public: + // Build with the tensors in 'tensor_list'. + VectorOfTensors(const TfLiteContext& context, + const TfLiteIntArray& tensor_list) { + int num_tensors = tensor_list.size; + + all_data_.reserve(num_tensors); + all_dims_.reserve(num_tensors); + all_dims_ptr_.reserve(num_tensors); + + for (int i = 0; i < num_tensors; ++i) { + TfLiteTensor* t = &context.tensors[tensor_list.data[i]]; + all_data_.push_back(GetTensorData(t)); + all_dims_.push_back(GetTensorDims(t)); + } + + // Taking the pointer from inside a std::vector is only OK if the vector is + // never modified, so we populate all_dims in the previous loop and then we + // are free to grab iterators here. + for (int i = 0; i < num_tensors; ++i) { + all_dims_ptr_.push_back(&all_dims_[i]); + } + } + // Return a pointer to the data pointers of all tensors in the list. For + // example: + // float* const* f = v.data(); + // f[0][1] is the second element of the first tensor. + T* const* data() const { return all_data_.data(); } + + // Return a pointer the dim pointers of all tensors in the list. For + // example: + // const Dims<4>* const* d = v.dims(); + // dims[1] are the dimensions of the second tensor in the list. + const Dims<4>* const* dims() const { return all_dims_ptr_.data(); } + + private: + std::vector all_data_; + std::vector> all_dims_; + std::vector*> all_dims_ptr_; +}; + } // namespace tflite #endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ diff --git a/tensorflow/contrib/lite/kernels/log_softmax_test.cc b/tensorflow/contrib/lite/kernels/log_softmax_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..62820a2f5113cb6ae252386aaf3842135383b79f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/log_softmax_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite LOG_SOFTMAX op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +class LogSoftmaxOpModel : public SingleOpModel { + public: + LogSoftmaxOpModel(int batches, int size) + : batches_(batches), input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_LOG_SOFTMAX, BuiltinOptions_LogSoftmaxOptions, + CreateLogSoftmaxOptions(builder_).Union()); + BuildInterpreter({{batches_, input_size_}}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; + + int batches_; + int input_size_; +}; + +TEST(LogSoftmaxOpTest, SimpleTest) { + LogSoftmaxOpModel m(/*batches=*/2, /*size=*/5); + m.SetInput({ + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }); + + m.Invoke(); + + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-4.45191431, -3.45191431, -2.45191431, -1.45191443, -0.4519144, + -0.4519144, -1.45191443, -2.45191431, -3.45191431, -4.45191431}, + 1e-6))); +} + +TEST(LogSoftmaxOpTest, CompareWithTFmini) { + const int batch_size = 2; + const int input_size = 5; + static float input_buffer[] = { + 1.0, 2.0, 3.0, 4.0, 5.0, // b = 0 + -1.0, -2.0, -3.0, -4.0, -5.0, // b = 1 + }; + + LogSoftmaxOpModel m(batch_size, input_size); + + m.SetInput(0, input_buffer, input_buffer + input_size * batch_size); + + m.Invoke(); + + std::unique_ptr output_buffer(new float[input_size * batch_size]); + static tflite::Dims<4> input_dims = {{input_size, 1, 1, batch_size}, + {1, 0, 0, input_size}}; + tflite::reference_ops::LogSoftmax(input_buffer, input_dims, + output_buffer.get(), input_dims); + + std::vector expected; + expected.insert(expected.end(), output_buffer.get(), + output_buffer.get() + input_size * batch_size); + + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(expected, 1e-6))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/mean_test.cc index c4c53c2ded351849e7c458fc754c36395a25ebd0..2d6d4bc2da4b75289ee27c3f2a12787216716d44 100644 --- a/tensorflow/contrib/lite/kernels/mean_test.cc +++ b/tensorflow/contrib/lite/kernels/mean_test.cc @@ -74,7 +74,7 @@ class MeanOpDynamicModel : public BaseMeanOpModel { } }; -TEST(ConstMeanOpTest, NotKeepDims) { +TEST(ConstFloatMeanOpTest, NotKeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; @@ -86,7 +86,7 @@ TEST(ConstMeanOpTest, NotKeepDims) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); } -TEST(ConstMeanOpTest, KeepDims) { +TEST(ConstFloatMeanOpTest, KeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; @@ -99,7 +99,7 @@ TEST(ConstMeanOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); } -TEST(DynamicMeanOpTest, NotKeepDims) { +TEST(DynamicFloatMeanOpTest, NotKeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; @@ -114,7 +114,7 @@ TEST(DynamicMeanOpTest, NotKeepDims) { EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); } -TEST(DynamicMeanOpTest, KeepDims) { +TEST(DynamicFloatMeanOpTest, KeepDims) { std::initializer_list data = { 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; @@ -130,6 +130,70 @@ TEST(DynamicMeanOpTest, KeepDims) { ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); } +TEST(DynamicFloatMeanOpTest, Scale) { + std::initializer_list data = {9.527}; + MeanOpDynamicModel m({TensorType_FLOAT32, {1}}, {TensorType_FLOAT32, {1}}, + {TensorType_INT32, {1}}, true); + std::initializer_list axis = {0}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({9.527}))); +} + +TEST(ConstUint8MeanOpTest, NotKeepDims) { + std::initializer_list data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24}; + MeanOpConstModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({12, 13})); +} + +TEST(ConstUint8MeanOpTest, KeepDims) { + std::initializer_list data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24}; + MeanOpConstModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({10, 12, 14})); +} + +TEST(DynamicUint8MeanOpTest, NotKeepDims) { + std::initializer_list data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24}; + MeanOpDynamicModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {2}}, + {TensorType_INT32, {4}}, false); + std::initializer_list axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({12, 13})); +} + +TEST(DynamicUint8MeanOpTest, KeepDims) { + std::initializer_list data = {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16, + 17, 18, 19, 20, 21, 22, 23, 24}; + MeanOpDynamicModel m({TensorType_UINT8, {4, 3, 2}}, {TensorType_UINT8, {3}}, + {TensorType_INT32, {2}}, true); + std::initializer_list axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({10, 12, 14})); +} + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 1fb779fd5174a255b7d34322c57dc084f68d8c3f..aea6f8d9d34420363cc1045425f3d27b12af449e 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -49,6 +49,7 @@ TfLiteRegistration* Register_MUL(); TfLiteRegistration* Register_L2_NORMALIZATION(); TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION(); TfLiteRegistration* Register_LSTM(); +TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_LSTM(); TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM(); TfLiteRegistration* Register_PAD(); TfLiteRegistration* Register_RESHAPE(); @@ -58,8 +59,12 @@ TfLiteRegistration* Register_SPACE_TO_DEPTH(); TfLiteRegistration* Register_GATHER(); TfLiteRegistration* Register_TRANSPOSE(); TfLiteRegistration* Register_MEAN(); +TfLiteRegistration* Register_SPLIT(); TfLiteRegistration* Register_SQUEEZE(); TfLiteRegistration* Register_STRIDED_SLICE(); +TfLiteRegistration* Register_EXP(); +TfLiteRegistration* Register_TOPK_V2(); +TfLiteRegistration* Register_LOG_SOFTMAX(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -94,6 +99,8 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORMALIZATION()); AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, + Register_BIDIRECTIONAL_SEQUENCE_LSTM()); AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, Register_UNIDIRECTIONAL_SEQUENCE_LSTM()); AddBuiltin(BuiltinOperator_PAD, Register_PAD()); @@ -106,8 +113,12 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_MEAN, Register_MEAN()); AddBuiltin(BuiltinOperator_DIV, Register_DIV()); AddBuiltin(BuiltinOperator_SUB, Register_SUB()); + AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT()); AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); + AddBuiltin(BuiltinOperator_EXP, Register_EXP()); + AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); + AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git a/tensorflow/contrib/lite/kernels/split.cc b/tensorflow/contrib/lite/kernels/split.cc new file mode 100644 index 0000000000000000000000000000000000000000..b524c79f8779b0119781679c0af9fe354e38ad4f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/split.cc @@ -0,0 +1,159 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace split { + +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + axis = GetInput(context, node, 0); + input = GetInput(context, node, 1); + } + TfLiteSplitParams* params; + TfLiteTensor* axis; + TfLiteTensor* input; +}; + +TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) { + for (int i = 0; i < NumOutputs(node); ++i) { + SetTensorToDynamic(GetOutput(context, node, i)); + } + return kTfLiteOk; +} + +TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node, + TfLiteTensor* axis, TfLiteTensor* input, + int num_splits) { + int axis_value = GetTensorData(axis)[0]; + if (axis_value < 0) { + axis_value += NumDimensions(input); + } + + const int input_size = SizeOfDimension(input, axis_value); + TF_LITE_ENSURE_MSG(context, input_size % num_splits == 0, + "Not an even split"); + const int slice_size = input_size / num_splits; + + for (int i = 0; i < NumOutputs(node); ++i) { + TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims); + output_dims->data[axis_value] = slice_size; + TfLiteTensor* output = GetOutput(context, node, i); + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims)); + } + + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + + OpContext op_context(context, node); + + TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits); + + auto input_type = op_context.input->type; + TF_LITE_ENSURE(context, + input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8); + for (int i = 0; i < NumOutputs(node); ++i) { + GetOutput(context, node, i)->type = input_type; + } + + // If we know the contents of the 'axis' tensor, resize all outputs. + // Otherwise, wait until Eval(). + if (IsConstantTensor(op_context.axis)) { + return ResizeOutputTensors(context, node, op_context.axis, op_context.input, + op_context.params->num_splits); + } else { + return UseDynamicOutputTensors(context, node); + } +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + + // When the 'axis' tensor is non-const we can't resize output tensors in + // Prepare(), and we have to do it now. + if (!IsConstantTensor(op_context.axis)) { + TF_LITE_ENSURE_OK( + context, + ResizeOutputTensors(context, node, op_context.axis, op_context.input, + op_context.params->num_splits)); + } + + int axis_value = GetTensorData(op_context.axis)[0]; + if (axis_value < 0) { + axis_value += NumDimensions(op_context.input); + } + axis_value = RemapDim(NumDimensions(op_context.input), axis_value); + + // TODO(ahentz): Our usage of VectorOfTensors could be optimized by + // calculating it in Prepare, unless we defer shape calculation. + // TODO(ahentz): We can improve the optimized_ops version to handle other + // cases too. +#define TF_LITE_SPLIT(scalar) \ + VectorOfTensors all_outputs(*context, *node->outputs); \ + if (axis_value == NumDimensions(op_context.input)) { \ + optimized_ops::TensorFlowSplit( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), NumOutputs(node), all_outputs.data(), \ + all_outputs.dims()); \ + } else { \ + reference_ops::TensorFlowSplit( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), axis_value, NumOutputs(node), \ + all_outputs.data(), all_outputs.dims()); \ + } + switch (op_context.input->type) { + case kTfLiteFloat32: { + TF_LITE_SPLIT(float); + break; + } + case kTfLiteUInt8: { + TF_LITE_SPLIT(uint8_t); + break; + } + default: + context->ReportError(context, + "Only float32 and uint8 are currently supported."); + return kTfLiteError; + } +#undef TF_LITE_SPLIT + + return kTfLiteOk; +} + +} // namespace split + +TfLiteRegistration* Register_SPLIT() { + static TfLiteRegistration r = {nullptr, nullptr, split::Prepare, split::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/split_test.cc b/tensorflow/contrib/lite/kernels/split_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..61a0759c6475795c06a9b55d3586d2b818f298b2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/split_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +constexpr int kAxisIsATensor = -1000; + +class SplitOpModel : public SingleOpModel { + public: + SplitOpModel(const TensorData& input, int num_splits, + int axis = kAxisIsATensor) { + if (axis == kAxisIsATensor) { + axis_ = AddInput({TensorType_INT32, {1}}); + } else { + axis_ = AddConstInput(TensorType_INT32, {axis}, {1}); + } + input_ = AddInput(input); + for (int i = 0; i < num_splits; ++i) { + outputs_.push_back(AddOutput(input.type)); + } + SetBuiltinOp(BuiltinOperator_SPLIT, BuiltinOptions_SplitOptions, + CreateSplitOptions(builder_, num_splits).Union()); + if (axis == kAxisIsATensor) { + BuildInterpreter({GetShape(axis_), GetShape(input_)}); + } else { + BuildInterpreter({{}, GetShape(input_)}); + } + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetAxis(int axis) { PopulateTensor(axis_, {axis}); } + + std::vector GetOutput(int i) { + return ExtractVector(outputs_[i]); + } + std::vector GetOutputShape(int i) { return GetTensorShape(outputs_[i]); } + + private: + int input_; + int axis_; + std::vector outputs_; +}; + +using TensorValues = std::initializer_list; + +void Check(int axis, int num_splits, std::initializer_list input_shape, + std::initializer_list output_shape, + const TensorValues& input_data, + const std::vector& output_data) { + auto debug = [&](int i) { + std::stringstream ss; + ss << "for output tensor " << i << " axis=" << axis + << " and num_splits=" << num_splits; + return ss.str(); + }; + SplitOpModel m({TensorType_FLOAT32, input_shape}, num_splits); + m.SetInput(input_data); + m.SetAxis(axis); + m.Invoke(); + for (int i = 0; i < num_splits; ++i) { + EXPECT_THAT(m.GetOutput(i), ElementsAreArray(output_data[i])) << debug(i); + EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shape)) + << debug(i); + } + + SplitOpModel const_m({TensorType_FLOAT32, input_shape}, num_splits, axis); + const_m.SetInput(input_data); + const_m.Invoke(); + for (int i = 0; i < num_splits; ++i) { + EXPECT_THAT(const_m.GetOutput(i), ElementsAreArray(output_data[i])) + << debug(i); + EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shape)) + << debug(i); + } +} + +TEST(SplitOpTest, FourDimensional) { + Check(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }); + Check(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 9, 10, 11, 12}, + {5, 6, 7, 8, 13, 14, 15, 16}, + }); + Check(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 5, 6, 9, 10, 13, 14}, + {3, 4, 7, 8, 11, 12, 15, 16}, + }); + Check(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 3, 5, 7, 9, 11, 13, 15}, + {2, 4, 6, 8, 10, 12, 14, 16}, + }); +} + +TEST(SplitOpTest, OneDimensional) { + Check(/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8}, + {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}}); +} + +TEST(SplitOpTest, NegativeAxis) { + Check(/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + { + {1, 2, 3, 4, 5, 6, 7, 8}, + {9, 10, 11, 12, 13, 14, 15, 16}, + }); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index 6f56aa6bf38781e860e33e8ac3b6a0bb8b50bb01..373310bd87370a670a847cf5328633956028a850 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -187,6 +187,7 @@ void SingleOpModel::BuildInterpreter( for (const auto& shape : input_shapes) { int input_idx = interpreter_->inputs()[i++]; if (input_idx == kOptionalTensor) continue; + if (shape.empty()) continue; CHECK(interpreter_->ResizeInputTensor(input_idx, shape) == kTfLiteOk); } CHECK(interpreter_->AllocateTensors() == kTfLiteOk) diff --git a/tensorflow/contrib/lite/kernels/topk_v2.cc b/tensorflow/contrib/lite/kernels/topk_v2.cc new file mode 100644 index 0000000000000000000000000000000000000000..807e84609f8b23d25324d99d26086331d78a0684 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/topk_v2.cc @@ -0,0 +1,232 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +namespace tflite { +namespace ops { +namespace builtin { +namespace topk_v2 { +constexpr int kInputTensor = 0; +constexpr int kInputTopK = 1; +constexpr int kOutputIndexes = 0; +constexpr int kOutputValues = 1; + +namespace { +TfLiteStatus ResizeOutput(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* top_k = GetInput(context, node, kInputTopK); + // INT32 number of top results is supported. + TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32); + // Check that the tensor contains only one value. + TF_LITE_ENSURE_EQ(context, NumDimensions(top_k), 1); + TF_LITE_ENSURE_EQ(context, NumElements(top_k), 1); + const int32 k = top_k->data.i32[0]; + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + const int num_dimensions = NumDimensions(input); + // Check that input has one or more dimensions. + TF_LITE_ENSURE_MSG(context, input->dims->size >= 1, + "TopK k input must have 1 or more dimensions."); + // Check that k is less or equal the internal dimension. + TF_LITE_ENSURE_MSG(context, k <= input->dims->data[num_dimensions - 1], + "TopK k is higher than the internal dimension."); + + TfLiteIntArray* output_indexes_shape = TfLiteIntArrayCreate(num_dimensions); + TfLiteIntArray* output_values_shape = TfLiteIntArrayCreate(num_dimensions); + for (int i = 0; i < num_dimensions - 1; ++i) { + output_indexes_shape->data[i] = input->dims->data[i]; + output_values_shape->data[i] = input->dims->data[i]; + } + output_indexes_shape->data[num_dimensions - 1] = k; + output_values_shape->data[num_dimensions - 1] = k; + TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); + TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + auto resize_tensor = [context](TfLiteTensor* tensor, TfLiteIntArray* new_size, + TfLiteIntArray* delete_on_error) { + TfLiteStatus status = context->ResizeTensor(context, tensor, new_size); + if (status != kTfLiteOk) { + TfLiteIntArrayFree(new_size); + if (delete_on_error != nullptr) { + TfLiteIntArrayFree(delete_on_error); + } + } + return status; + }; + TF_LITE_ENSURE_OK(context, resize_tensor(output_indexes, output_indexes_shape, + output_values_shape)); + TF_LITE_ENSURE_OK(context, + resize_tensor(output_values, output_values_shape, nullptr)); + return kTfLiteOk; +} + +// The class that collects top indexes of k values. Based on template +// tensorflow::gtl::TopN<> but, for optimization, +// it re-uses the same container. +template +class TopContainer { + public: + TopContainer() = delete; + TopContainer(int32 k, int32 row_size) : k_(k) { + container_.reserve(std::min(k, row_size) + 1); + } + + void start_collecting(const T* values) { + values_ = values; + container_.clear(); + } + void push(int32 a) { + auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); }; + if (container_.size() <= k_) { + container_.push_back(a); + if (container_.size() == k_ + 1) { + std::make_heap(container_.begin(), container_.end(), comparator); + std::pop_heap(container_.begin(), container_.end(), comparator); + } + } else if (comparator(a, container_.front())) { + container_.back() = a; + std::push_heap(container_.begin(), container_.end(), comparator); + std::pop_heap(container_.begin(), container_.end(), comparator); + } + } + + const std::vector& sorted_result() { + auto comparator = [this](int32 a, int32 b) { return compare_fun(a, b); }; + if (container_.size() <= k_) { + std::sort(container_.begin(), container_.end(), comparator); + } else { + std::sort_heap(container_.begin(), container_.end() - 1, comparator); + container_.resize(k_); + } + return container_; + } + + private: + int32 k_; + std::vector container_; + const T* values_ = nullptr; + + bool compare_fun(int32 a, int32 b) const { + if (values_[b] < values_[a]) { + return true; + } else if (values_[b] > values_[a]) { + return false; + } else { + return a < b; + } + } +}; + +// Mostly modeled on tensorflow/core/kernels/topk_op.cc for CPU. +template +void TopK(int32 row_size, int32 num_rows, const T* data, int32 k, + int32* output_indexes, T* output_values) { + TopContainer topc(k, row_size); + for (int row = 0; row < num_rows; ++row) { + const T* values_row = data + row * row_size; + topc.start_collecting(values_row); + for (int32 c = 0; c < row_size; ++c) { + topc.push(c); + } + + // Prepare output buffers. + int32* indexes_row = output_indexes + row * k; + T* output_row = output_values + row * k; + // We always assume that the output is sorted. + const auto& top_k = topc.sorted_result(); + std::copy(top_k.begin(), top_k.end(), indexes_row); + std::transform(top_k.begin(), top_k.end(), output_row, + [values_row](const int32 loc) { return values_row[loc]; }); + } +} + +} // namespace + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check that the inputs and outputs have the right sizes and types. + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 2); + + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + TF_LITE_ENSURE_EQ(context, input->type, output_values->type); + + TfLiteTensor* top_k = GetInput(context, node, kInputTopK); + TF_LITE_ENSURE_EQ(context, top_k->type, kTfLiteInt32); + + // Set output dynamic if the input is not const. + if (IsConstantTensor(top_k)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } else { + TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); + TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + SetTensorToDynamic(output_indexes); + SetTensorToDynamic(output_values); + } + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* output_values = GetOutput(context, node, kOutputValues); + TfLiteTensor* output_indexes = GetOutput(context, node, kOutputIndexes); + if (IsDynamicTensor(output_values)) { + TF_LITE_ENSURE_OK(context, ResizeOutput(context, node)); + } + TfLiteTensor* top_k = GetInput(context, node, kInputTopK); + const int32 k = top_k->data.i32[0]; + // The tensor can have more than 2 dimensions or even be a vector, the code + // anyway calls the internal dimension as row; + TfLiteTensor* input = GetInput(context, node, kInputTensor); + const int32 row_size = input->dims->data[input->dims->size - 1]; + int32 num_rows = 1; + for (int i = 0; i < input->dims->size - 1; ++i) { + num_rows *= input->dims->data[i]; + } + switch (output_values->type) { + case kTfLiteFloat32: + TopK(row_size, num_rows, input->data.f, k, output_indexes->data.i32, + output_values->data.f); + break; + case kTfLiteUInt8: + TopK(row_size, num_rows, input->data.uint8, k, output_indexes->data.i32, + output_values->data.uint8); + break; + case kTfLiteInt32: + TopK(row_size, num_rows, input->data.i32, k, output_indexes->data.i32, + output_values->data.i32); + break; + case kTfLiteInt64: + TopK(row_size, num_rows, input->data.i64, k, output_indexes->data.i32, + output_values->data.i64); + break; + default: + context->ReportError(context, "Type is currently not supported by TopK."); + return kTfLiteError; + } + + return kTfLiteOk; +} +} // namespace topk_v2 +TfLiteRegistration* Register_TOPK_V2() { + static TfLiteRegistration r = {nullptr, nullptr, topk_v2::Prepare, + topk_v2::Eval}; + return &r; +} +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/topk_v2_test.cc b/tensorflow/contrib/lite/kernels/topk_v2_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..29f2a057cd45e1cded3ff1aa0f0fdcad666ce2fa --- /dev/null +++ b/tensorflow/contrib/lite/kernels/topk_v2_test.cc @@ -0,0 +1,155 @@ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class TopKV2OpModel : public SingleOpModel { + public: + TopKV2OpModel(std::initializer_list input_shape, TensorType input_type, + int top_k) { + input_ = AddInput(input_type); + top_k_ = AddInput(TensorType_INT32); + output_indexes_ = AddOutput(TensorType_INT32); + output_values_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_TOPK_V2, BuiltinOptions_TopKV2Options, 0); + BuildInterpreter({input_shape, {1}}); + PopulateTensor(top_k_, {top_k}); + } + + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputUInt8(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt32(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputInt64(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetIndexes() { + return ExtractVector(output_indexes_); + } + + std::vector GetValuesFloat() { + return ExtractVector(output_values_); + } + + std::vector GetValuesUInt8() { + return ExtractVector(output_values_); + } + + std::vector GetValuesInt32() { + return ExtractVector(output_values_); + } + + std::vector GetValuesInt64() { + return ExtractVector(output_values_); + } + + protected: + int input_; + int top_k_; + int output_indexes_; + int output_values_; +}; + +// The test where the tensor dimension is equal to top. +TEST(TopKV2OpTest, EqualFloat) { + TopKV2OpModel m({2, 2}, TensorType_FLOAT32, 2); + m.SetInputFloat({-2.0, 0.2, 0.8, 0.1}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({1, 0, 0, 1})); + EXPECT_THAT(m.GetValuesFloat(), + ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1}))); +} + +// Test when internal dimension is k+1. +TEST(TopKV2OpTest, BorderFloat) { + TopKV2OpModel m({2, 3}, TensorType_FLOAT32, 2); + m.SetInputFloat({-2.0, -3.0, 0.2, 0.8, 0.1, -0.1}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 0, 0, 1})); + EXPECT_THAT(m.GetValuesFloat(), + ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1}))); +} +// Test when internal dimension is higher than k. +TEST(TopKV2OpTest, LargeFloat) { + TopKV2OpModel m({2, 4}, TensorType_FLOAT32, 2); + m.SetInputFloat({-2.0, -3.0, -4.0, 0.2, 0.8, 0.1, -0.1, -0.8}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({3, 0, 0, 1})); + EXPECT_THAT(m.GetValuesFloat(), + ElementsAreArray(ArrayFloatNear({0.2, -2.0, 0.8, 0.1}))); +} + +// Test 1D case. +TEST(TopKV2OpTest, VectorFloat) { + TopKV2OpModel m({8}, TensorType_FLOAT32, 2); + m.SetInputFloat({-2.0, -3.0, -4.0, 0.2, 0.8, 0.1, -0.1, -0.8}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({4, 3})); + EXPECT_THAT(m.GetValuesFloat(), ElementsAreArray(ArrayFloatNear({0.8, 0.2}))); +} + +// Check that uint8 works. +TEST(TopKV2OpTest, TypeUint8) { + TopKV2OpModel m({2, 3}, TensorType_UINT8, 2); + m.SetInputUInt8({1, 2, 3, 251, 250, 249}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1})); + EXPECT_THAT(m.GetValuesUInt8(), ElementsAreArray({3, 2, 251, 250})); +} + +// Check that int32 works. +TEST(TopKV2OpTest, TypeInt32) { + TopKV2OpModel m({2, 3}, TensorType_INT32, 2); + m.SetInputInt32({1, 2, 3, 10251, 10250, 10249}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1})); + EXPECT_THAT(m.GetValuesInt32(), ElementsAreArray({3, 2, 10251, 10250})); +} + +// Check that int64 works. +TEST(TopKV2OpTest, TypeInt64) { + TopKV2OpModel m({2, 3}, TensorType_INT64, 2); + m.SetInputInt64({1, 2, 3, -1, -2, -3}); + m.Invoke(); + EXPECT_THAT(m.GetIndexes(), ElementsAreArray({2, 1, 0, 1})); + EXPECT_THAT(m.GetValuesInt64(), ElementsAreArray({3, 2, -1, -2})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 14b6709964b54a6532273a69cca51c560b1cc103..725f2838c574fcc2ba389401f92575279ebc144c 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -124,19 +124,25 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() { auto opcodes = model_->operator_codes(); for (const OperatorCode* opcode : *opcodes) { TfLiteRegistration* registration = nullptr; - - if (opcode->builtin_code() != BuiltinOperator_CUSTOM) { - auto x = opcode->builtin_code(); - flatbuffer_op_index_to_registration_types_.push_back(x); - registration = op_resolver_.FindOp(x); + auto builtin_code = opcode->builtin_code(); + if (builtin_code > BuiltinOperator_MAX || + builtin_code < BuiltinOperator_MIN) { + error_reporter_->Report( + "Op builtin_code out or range: %d. Are you using old TFLite binary " + "with newer model?", + builtin_code); + status = kTfLiteError; + } else if (builtin_code != BuiltinOperator_CUSTOM) { + flatbuffer_op_index_to_registration_types_.push_back(builtin_code); + registration = op_resolver_.FindOp(builtin_code); if (registration == nullptr) { error_reporter_->Report("Didn't find op for builtin opcode '%s'\n", - EnumNameBuiltinOperator(x)); + EnumNameBuiltinOperator(builtin_code)); status = kTfLiteError; } } else if (!opcode->custom_code()) { error_reporter_->Report( - "Operator with builtin_code==0 has no custom_code.\n"); + "Operator with CUSTOM builtin_code has no custom_code.\n"); status = kTfLiteError; } else { const char* name = opcode->custom_code()->c_str(); @@ -278,6 +284,9 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_RELU_N1_TO_1: case BuiltinOperator_RELU6: case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_EXP: + case BuiltinOperator_TOPK_V2: + case BuiltinOperator_LOG_SOFTMAX: break; case BuiltinOperator_LSH_PROJECTION: { TfLiteLSHProjectionParams* params = @@ -453,6 +462,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_LSTM: { TfLiteLSTMParams* params = MallocPOD(); @@ -533,6 +543,14 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_SPLIT: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_SplitOptions()) { + params->num_splits = schema_params->num_splits(); + } + builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_SQUEEZE: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) { @@ -556,6 +574,11 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_DELEGATE: { + // TODO(ycling): Revisit when supporting saving delegated models. + error_reporter->Report("DELEGATE op shouldn't exist in model."); + break; + } } return builtin_data; } @@ -769,6 +792,8 @@ TfLiteStatus InterpreterBuilder::operator()( return cleanup_and_error(); } + (**interpreter).set_model(model_); + // Parse inputs/outputs (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs())); (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs())); diff --git a/tensorflow/contrib/lite/models/speech_test.cc b/tensorflow/contrib/lite/models/speech_test.cc index daa8c3100b64e9290256aa14a6ab641f19174a0a..a354179a9480c136d65f83836d81f69c2089fdbe 100644 --- a/tensorflow/contrib/lite/models/speech_test.cc +++ b/tensorflow/contrib/lite/models/speech_test.cc @@ -97,7 +97,12 @@ bool ConvertCsvData(const string& model_name, const string& in_name, return true; } -TEST(SpeechTest, HotwordOkGoogleRank1Test) { +class SpeechTest : public ::testing::TestWithParam { + protected: + int GetMaxInvocations() { return GetParam(); } +}; + +TEST_P(SpeechTest, HotwordOkGoogleRank1Test) { std::stringstream os; ASSERT_TRUE(ConvertCsvData( "speech_hotword_model_rank1.tflite", "speech_hotword_model_in.csv", @@ -105,11 +110,11 @@ TEST(SpeechTest, HotwordOkGoogleRank1Test) { /*output_tensor=*/"18", /*persistent_tensors=*/"4", /*sequence_size=*/40, &os)); testing::TfLiteDriver test_driver(/*use_nnapi=*/false); - ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations())) << test_driver.GetErrorMessage(); } -TEST(SpeechTest, HotwordOkGoogleRank2Test) { +TEST_P(SpeechTest, HotwordOkGoogleRank2Test) { std::stringstream os; ASSERT_TRUE(ConvertCsvData( "speech_hotword_model_rank2.tflite", "speech_hotword_model_in.csv", @@ -117,11 +122,11 @@ TEST(SpeechTest, HotwordOkGoogleRank2Test) { /*output_tensor=*/"18", /*persistent_tensors=*/"1", /*sequence_size=*/40, &os)); testing::TfLiteDriver test_driver(/*use_nnapi=*/false); - ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations())) << test_driver.GetErrorMessage(); } -TEST(SpeechTest, SpeakerIdOkGoogleTest) { +TEST_P(SpeechTest, SpeakerIdOkGoogleTest) { std::stringstream os; ASSERT_TRUE(ConvertCsvData( "speech_speakerid_model.tflite", "speech_speakerid_model_in.csv", @@ -130,11 +135,11 @@ TEST(SpeechTest, SpeakerIdOkGoogleTest) { /*persistent_tensors=*/"19,20,40,41,61,62", /*sequence_size=*/80, &os)); testing::TfLiteDriver test_driver(/*use_nnapi=*/false); - ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations())) << test_driver.GetErrorMessage(); } -TEST(SpeechTest, AsrAmTest) { +TEST_P(SpeechTest, AsrAmTest) { std::stringstream os; ASSERT_TRUE( ConvertCsvData("speech_asr_am_model.tflite", "speech_asr_am_model_in.csv", @@ -143,7 +148,7 @@ TEST(SpeechTest, AsrAmTest) { /*persistent_tensors=*/"19,20,40,41,61,62,82,83,103,104", /*sequence_size=*/320, &os)); testing::TfLiteDriver test_driver(/*use_nnapi=*/false); - ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations())) << test_driver.GetErrorMessage(); } @@ -151,15 +156,16 @@ TEST(SpeechTest, AsrAmTest) { // through the interpreter and stored the sum of all the output, which was them // compared for correctness. In this test we are comparing all the intermediate // results. -TEST(SpeechTest, AsrLmTest) { +TEST_P(SpeechTest, AsrLmTest) { std::ifstream in_file; testing::TfLiteDriver test_driver(/*use_nnapi=*/false); ASSERT_TRUE(Init("speech_asr_lm_model.test_spec", &test_driver, &in_file)); - ASSERT_TRUE(testing::ParseAndRunTests(&in_file, &test_driver)) + ASSERT_TRUE( + testing::ParseAndRunTests(&in_file, &test_driver, GetMaxInvocations())) << test_driver.GetErrorMessage(); } -TEST(SpeechTest, EndpointerTest) { +TEST_P(SpeechTest, EndpointerTest) { std::stringstream os; ASSERT_TRUE(ConvertCsvData( "speech_endpointer_model.tflite", "speech_endpointer_model_in.csv", @@ -168,11 +174,11 @@ TEST(SpeechTest, EndpointerTest) { /*persistent_tensors=*/"28,29,49,50", /*sequence_size=*/320, &os)); testing::TfLiteDriver test_driver(/*use_nnapi=*/false); - ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations())) << test_driver.GetErrorMessage(); } -TEST(SpeechTest, TtsTest) { +TEST_P(SpeechTest, TtsTest) { std::stringstream os; ASSERT_TRUE(ConvertCsvData("speech_tts_model.tflite", "speech_tts_model_in.csv", @@ -181,9 +187,19 @@ TEST(SpeechTest, TtsTest) { /*persistent_tensors=*/"25,26,46,47,67,68,73", /*sequence_size=*/334, &os)); testing::TfLiteDriver test_driver(/*use_nnapi=*/false); - ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver)) + ASSERT_TRUE(testing::ParseAndRunTests(&os, &test_driver, GetMaxInvocations())) << test_driver.GetErrorMessage(); } +// Define two instantiations. The "ShortTests" instantiations is used when +// running the tests on Android, in order to prevent timeouts (It takes about +// 200s just to bring up the Android emulator.) +static const int kAllInvocations = -1; +static const int kFirstFewInvocations = 10; +INSTANTIATE_TEST_CASE_P(LongTests, SpeechTest, + ::testing::Values(kAllInvocations)); +INSTANTIATE_TEST_CASE_P(ShortTests, SpeechTest, + ::testing::Values(kFirstFewInvocations)); + } // namespace } // namespace tflite diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index da9ceec2f1401745ba477824bf494ee5b0ee1187..e631ffd845d3b31232070b935c12aa8a2e8ce05e 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -323,6 +323,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: case tflite::BuiltinOperator_EMBEDDING_LOOKUP: case tflite::BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: + case tflite::BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case tflite::BuiltinOperator_L2_NORMALIZATION: case tflite::BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: @@ -335,12 +336,17 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_GATHER: case tflite::BuiltinOperator_SPACE_TO_BATCH_ND: case tflite::BuiltinOperator_BATCH_TO_SPACE_ND: + case tflite::BuiltinOperator_TOPK_V2: case tflite::BuiltinOperator_TRANSPOSE: case tflite::BuiltinOperator_MEAN: case tflite::BuiltinOperator_DIV: case tflite::BuiltinOperator_SUB: + case tflite::BuiltinOperator_SPLIT: case tflite::BuiltinOperator_SQUEEZE: case tflite::BuiltinOperator_STRIDED_SLICE: + case tflite::BuiltinOperator_EXP: + case tflite::BuiltinOperator_LOG_SOFTMAX: + case tflite::BuiltinOperator_DELEGATE: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD index 2d8c49b7d7a5ae5c180f100a399a1870679c455f..82feae0f0041997949212613c654a5695f468d56 100644 --- a/tensorflow/contrib/lite/python/BUILD +++ b/tensorflow/contrib/lite/python/BUILD @@ -28,6 +28,7 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/contrib/framework:framework_py", + "//tensorflow/core:protos_all_py", "//tensorflow/python:platform", ], ) diff --git a/tensorflow/contrib/lite/python/op_hint.py b/tensorflow/contrib/lite/python/op_hint.py index 7c587e38b16dc3011fc7c8bef4eec4d0ea99ec21..9a3971228a683211e84b4c55d3a3e8d574b5ed94 100644 --- a/tensorflow/contrib/lite/python/op_hint.py +++ b/tensorflow/contrib/lite/python/op_hint.py @@ -73,6 +73,7 @@ import itertools as _itertools import uuid as _uuid from tensorflow.contrib import framework as _framework +from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2 from tensorflow.python.framework import ops as _ops from tensorflow.python.ops import array_ops as _array_ops from tensorflow.python.util.all_util import remove_undocumented @@ -133,10 +134,17 @@ class OpHint(object): def augmented_identity(arg): identity_op = _array_ops.identity(arg) - attr = identity_op.op.node_def.attr - attr[OpHint.FUNCTION_NAME_ATTR].s = self._function_name - attr[OpHint.FUNCTION_UUID_ATTR].s = self._unique_function_id - attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i = self._curr_input_index + # pylint: disable=protected-access + identity_op.op._set_attr( + OpHint.FUNCTION_NAME_ATTR, + _attr_value_pb2.AttrValue(s=self._function_name)) + identity_op.op._set_attr( + OpHint.FUNCTION_UUID_ATTR, + _attr_value_pb2.AttrValue(s=self._unique_function_id)) + identity_op.op._set_attr( + OpHint.FUNCTION_INPUT_INDEX_ATTR, + _attr_value_pb2.AttrValue(i=self._curr_input_index)) + # pylint: enable=protected-access self._curr_input_index += 1 return identity_op @@ -154,10 +162,17 @@ class OpHint(object): def augmented_identity(arg): identity_op = _array_ops.identity(arg) - attr = identity_op.op.node_def.attr - attr[OpHint.FUNCTION_NAME_ATTR].s = self._function_name - attr[OpHint.FUNCTION_UUID_ATTR].s = self._unique_function_id - attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i = self._curr_output_index + # pylint: disable=protected-access + identity_op.op._set_attr( + OpHint.FUNCTION_NAME_ATTR, + _attr_value_pb2.AttrValue(s=self._function_name)) + identity_op.op._set_attr( + OpHint.FUNCTION_UUID_ATTR, + _attr_value_pb2.AttrValue(s=self._unique_function_id)) + identity_op.op._set_attr( + OpHint.FUNCTION_OUTPUT_INDEX_ATTR, + _attr_value_pb2.AttrValue(i=self._curr_output_index)) + # pylint: enable=protected-access self._curr_output_index += 1 return identity_op diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD b/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..0148149a6adc141d67e82808f7e8c72ddb7e309a --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/BUILD @@ -0,0 +1,43 @@ +package(default_visibility = [ + "//visibility:public", +]) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "generator", + srcs = ["generator.cc"], + hdrs = ["generator.h"], + deps = [ + "//tensorflow/contrib/lite/schema:schema_fbs", + ], +) + +cc_binary( + name = "generate", + srcs = ["generate.cc"], + deps = [ + ":generator", + ], +) + +cc_test( + name = "generator_test", + srcs = ["generator_test.cc"], + deps = [ + ":generator", + "@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "consistency_test", + srcs = ["consistency_test.cc"], + data = [ + "//tensorflow/contrib/lite:builtin_ops.h", + ], + deps = [ + ":generator", + "@com_google_googletest//:gtest", + ], +) diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/README.md b/tensorflow/contrib/lite/schema/builtin_ops_header/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f20d4f664e62fdd52e55339e45b9603307a2b671 --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/README.md @@ -0,0 +1,12 @@ +# Builtin Ops Header Generator. + +This directory contains a code generator to generate a pure C header for +builtin op definition. + +Whenever you add a new builtin op, please execute: + +```sh +bazel run \ + //tensorflow/contrib/lite/schema/builtin_ops_header:generate > \ + tensorflow/contrib/lite/builtin_ops.h +``` diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/consistency_test.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/consistency_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d55c125c117db3c1b8d67ab0b674abe2e7c39d94 --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/consistency_test.cc @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" + +namespace { + +const char* kHeaderFileName = + "tensorflow/contrib/lite/builtin_ops.h"; + +// The test ensures that `builtin_ops.h` is consistent with the FlatBuffer +// schema definition. When the schema is modified, it's required to run the +// generator to re-generate the header. +// Please see README.md for more details. +TEST(BuiltinOpsHeaderTest, TestConsistency) { + std::ifstream input_stream(kHeaderFileName, std::ios::binary); + ASSERT_TRUE(input_stream); + std::string file_content((std::istreambuf_iterator(input_stream)), + std::istreambuf_iterator()); + + std::ostringstream output_stream; + tflite::builtin_ops_header::GenerateHeader(output_stream); + std::string generated_content = output_stream.str(); + + EXPECT_EQ(file_content, generated_content); +} + +} // anonymous namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generate.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generate.cc new file mode 100644 index 0000000000000000000000000000000000000000..72a28987b8d4863b0f03f7861177940177edd884 --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generate.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" + +// This executable is used to generate builtin_ops.h in TensorFlow Lite. +// Please see README.md for more details. +int main() { + if (!tflite::builtin_ops_header::GenerateHeader(std::cout)) { + std::cerr << "Failed to generate the header file.\n"; + } + return 0; +} diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc new file mode 100644 index 0000000000000000000000000000000000000000..08bcfe451685f488be2c3bc180f2dfc43dfe4f05 --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.cc @@ -0,0 +1,135 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" + +namespace tflite { +namespace builtin_ops_header { + +namespace { +const char* kFileHeader = + R"(/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ + +// DO NOT EDIT MANUALLY: This file is automatically generated by +// `schema_builtin_ops_header_generator.py`. + +#ifdef __cplusplus +extern "C" { +#endif // __cplusplus + +// The enum for builtin operators. +// Note: CUSTOM and DELEGATE are 2 special ops which are not real biultin +// ops. +typedef enum { +)"; + +const char* kFileFooter = + R"(} TfLiteBuiltinOperator; + +#ifdef __cplusplus +} // extern "C" +#endif // __cplusplus +#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OPS_H_ +} +)"; +} // anonymous namespace + +bool IsValidInputEnumName(const std::string& name) { + const char* begin = name.c_str(); + const char* ch = begin; + while (*ch != '\0') { + // If it's not the first character, expect an underscore. + if (ch != begin) { + if (*ch != '_') { + return false; + } + ++ch; + } + + // Expecting a word with upper case letters or digits, like "CONV", + // "CONV2D", "2D"...etc. + bool empty = true; + while (isupper(*ch) || isdigit(*ch)) { + // It's not empty if at least one character is consumed. + empty = false; + ++ch; + } + if (empty) { + return false; + } + } + return true; +} + +std::string ConstantizeVariableName(const std::string& name) { + std::string result = "kTfLiteBuiltin"; + bool uppercase = true; + for (char input_char : name) { + if (input_char == '_') { + uppercase = true; + } else if (uppercase) { + result += toupper(input_char); + uppercase = false; + } else { + result += tolower(input_char); + } + } + + return result; +} + +bool GenerateHeader(std::ostream& os) { + auto enum_names = tflite::EnumNamesBuiltinOperator(); + + // Check if all the input enum names are valid. + for (auto enum_value : EnumValuesBuiltinOperator()) { + auto enum_name = enum_names[enum_value]; + if (!IsValidInputEnumName(enum_name)) { + std::cerr << "Invalid input enum name: " << enum_name << std::endl; + return false; + } + } + + os << kFileHeader; + for (auto enum_value : EnumValuesBuiltinOperator()) { + auto enum_name = enum_names[enum_value]; + os << " "; + os << ConstantizeVariableName(enum_name); + os << " = "; + os << enum_value; + os << ",\n"; + } + os << kFileFooter; + return true; +} + +} // namespace builtin_ops_header +} // namespace tflite diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator.h b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.h new file mode 100644 index 0000000000000000000000000000000000000000..3241ff83d599ed8a476fc1d5a88c26143ebfbaf2 --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator.h @@ -0,0 +1,38 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// An utility library to generate pure C header for builtin ops definition. +#ifndef TENSORFLOW_CONTRIB_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ +#define TENSORFLOW_CONTRIB_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ + +#include + +namespace tflite { +namespace builtin_ops_header { + +// Check if the input enum name (from the Flatbuffer definition) is valid. +bool IsValidInputEnumName(const std::string& name); + +// Convert the enum name from Flatbuffer convention to C enum name convention. +// E.g. `L2_POOL_2D` becomes `kTfLiteBuiltinL2Pool2d`. +std::string ConstantizeVariableName(const std::string& name); + +// The function generates a pure C header for builtin ops definition, and write +// it to the output stream. +bool GenerateHeader(std::ostream& os); + +} // namespace builtin_ops_header +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_SCHEMA_BUILTIN_OPS_HEADER_GENERATOR_H_ diff --git a/tensorflow/contrib/lite/schema/builtin_ops_header/generator_test.cc b/tensorflow/contrib/lite/schema/builtin_ops_header/generator_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a7dc8e1b0486eda6e09f38a209dca95c0317a1fb --- /dev/null +++ b/tensorflow/contrib/lite/schema/builtin_ops_header/generator_test.cc @@ -0,0 +1,63 @@ + +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/schema/builtin_ops_header/generator.h" +#include +#include + +namespace { + +using tflite::builtin_ops_header::ConstantizeVariableName; +using tflite::builtin_ops_header::IsValidInputEnumName; + +TEST(TestIsValidInputEnumName, TestWithValidInputNames) { + EXPECT_TRUE(IsValidInputEnumName("ADD")); + EXPECT_TRUE(IsValidInputEnumName("CONV_2D")); + EXPECT_TRUE(IsValidInputEnumName("L2_POOL_2D")); +} + +TEST(TestIsValidInputEnumName, TestWithLeadingUnderscore) { + EXPECT_FALSE(IsValidInputEnumName("_ADD")); + EXPECT_FALSE(IsValidInputEnumName("_CONV_2D")); +} + +TEST(TestIsValidInputEnumName, TestWithLowerCase) { + EXPECT_FALSE(IsValidInputEnumName("_AdD")); + EXPECT_FALSE(IsValidInputEnumName("_COnV_2D")); +} + +TEST(TestIsValidInputEnumName, TestWithOtherCharacters) { + EXPECT_FALSE(IsValidInputEnumName("_AdD!2D")); + EXPECT_FALSE(IsValidInputEnumName("_COnV?2D")); +} + +TEST(TestIsValidInputEnumName, TestWithDoubleUnderscores) { + EXPECT_FALSE(IsValidInputEnumName("ADD__2D")); + EXPECT_FALSE(IsValidInputEnumName("CONV__2D")); +} + +TEST(TestConstantizeVariableName, TestWithValidInputNames) { + EXPECT_EQ(ConstantizeVariableName("ADD"), "kTfLiteBuiltinAdd"); + EXPECT_EQ(ConstantizeVariableName("CONV_2D"), "kTfLiteBuiltinConv2d"); + EXPECT_EQ(ConstantizeVariableName("L2_POOL_2D"), "kTfLiteBuiltinL2Pool2d"); +} + +} // anonymous namespace + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 36cc2724eb1d927d39cff25a46a57aca4f572547..98ac0469d1b885aa8047d35c8d814da4b61eff0c 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -120,6 +120,15 @@ enum BuiltinOperator : byte { UNIDIRECTIONAL_SEQUENCE_LSTM = 44, STRIDED_SLICE = 45, BIDIRECTIONAL_SEQUENCE_RNN = 46, + EXP = 47, + TOPK_V2 = 48, + SPLIT = 49, + LOG_SOFTMAX = 50, + // DELEGATE is a special op type for the operations which are delegated to + // other backends. + // WARNING: Experimental interface, subject to change + DELEGATE = 51, + BIDIRECTIONAL_SEQUENCE_LSTM = 52, } // Options for the builtin operators. @@ -156,6 +165,10 @@ union BuiltinOptions { SqueezeOptions, SequenceRNNOptions, StridedSliceOptions, + ExpOptions, + TopKV2Options, + SplitOptions, + LogSoftmaxOptions, } enum Padding : byte { SAME, VALID } @@ -315,6 +328,9 @@ table DivOptions { fused_activation_function:ActivationFunctionType; } +table TopKV2Options { +} + enum CombinerType : byte { SUM = 0, MEAN = 1, @@ -332,6 +348,9 @@ table GatherOptions { table TransposeOptions { } +table ExpOptions { +} + table MeanOptions { keep_dims: bool; } @@ -340,6 +359,10 @@ table SqueezeOptions { squeeze_dims:[int]; } +table SplitOptions { + num_splits: int; +} + table StridedSliceOptions { begin_mask: int; end_mask: int; @@ -348,6 +371,9 @@ table StridedSliceOptions { shrink_axis_mask: int; } +table LogSoftmaxOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index e2ac0b9d1e05cdc7e89da32107044320d6e4ea5a..99e1accaa71ffc92514595a745fcb60115ef61a0 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ // automatically generated by the FlatBuffers compiler, do not modify + #ifndef FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ #define FLATBUFFERS_GENERATED_SCHEMA_TFLITE_H_ @@ -108,6 +109,9 @@ struct SubOptionsT; struct DivOptions; struct DivOptionsT; +struct TopKV2Options; +struct TopKV2OptionsT; + struct EmbeddingLookupSparseOptions; struct EmbeddingLookupSparseOptionsT; @@ -117,15 +121,24 @@ struct GatherOptionsT; struct TransposeOptions; struct TransposeOptionsT; +struct ExpOptions; +struct ExpOptionsT; + struct MeanOptions; struct MeanOptionsT; struct SqueezeOptions; struct SqueezeOptionsT; +struct SplitOptions; +struct SplitOptionsT; + struct StridedSliceOptions; struct StridedSliceOptionsT; +struct LogSoftmaxOptions; +struct LogSoftmaxOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -153,15 +166,27 @@ enum TensorType { }; inline TensorType (&EnumValuesTensorType())[6] { - static TensorType values[] = {TensorType_FLOAT32, TensorType_FLOAT16, - TensorType_INT32, TensorType_UINT8, - TensorType_INT64, TensorType_STRING}; + static TensorType values[] = { + TensorType_FLOAT32, + TensorType_FLOAT16, + TensorType_INT32, + TensorType_UINT8, + TensorType_INT64, + TensorType_STRING + }; return values; } inline const char **EnumNamesTensorType() { - static const char *names[] = {"FLOAT32", "FLOAT16", "INT32", "UINT8", - "INT64", "STRING", nullptr}; + static const char *names[] = { + "FLOAT32", + "FLOAT16", + "INT32", + "UINT8", + "INT64", + "STRING", + nullptr + }; return names; } @@ -215,108 +240,129 @@ enum BuiltinOperator { BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM = 44, BuiltinOperator_STRIDED_SLICE = 45, BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN = 46, + BuiltinOperator_EXP = 47, + BuiltinOperator_TOPK_V2 = 48, + BuiltinOperator_SPLIT = 49, + BuiltinOperator_LOG_SOFTMAX = 50, + BuiltinOperator_DELEGATE = 51, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM = 52, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN + BuiltinOperator_MAX = BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[44] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[50] { static BuiltinOperator values[] = { - BuiltinOperator_ADD, - BuiltinOperator_AVERAGE_POOL_2D, - BuiltinOperator_CONCATENATION, - BuiltinOperator_CONV_2D, - BuiltinOperator_DEPTHWISE_CONV_2D, - BuiltinOperator_EMBEDDING_LOOKUP, - BuiltinOperator_FULLY_CONNECTED, - BuiltinOperator_HASHTABLE_LOOKUP, - BuiltinOperator_L2_NORMALIZATION, - BuiltinOperator_L2_POOL_2D, - BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, - BuiltinOperator_LOGISTIC, - BuiltinOperator_LSH_PROJECTION, - BuiltinOperator_LSTM, - BuiltinOperator_MAX_POOL_2D, - BuiltinOperator_MUL, - BuiltinOperator_RELU, - BuiltinOperator_RELU_N1_TO_1, - BuiltinOperator_RELU6, - BuiltinOperator_RESHAPE, - BuiltinOperator_RESIZE_BILINEAR, - BuiltinOperator_RNN, - BuiltinOperator_SOFTMAX, - BuiltinOperator_SPACE_TO_DEPTH, - BuiltinOperator_SVDF, - BuiltinOperator_TANH, - BuiltinOperator_CONCAT_EMBEDDINGS, - BuiltinOperator_SKIP_GRAM, - BuiltinOperator_CALL, - BuiltinOperator_CUSTOM, - BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, - BuiltinOperator_PAD, - BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, - BuiltinOperator_GATHER, - BuiltinOperator_BATCH_TO_SPACE_ND, - BuiltinOperator_SPACE_TO_BATCH_ND, - BuiltinOperator_TRANSPOSE, - BuiltinOperator_MEAN, - BuiltinOperator_SUB, - BuiltinOperator_DIV, - BuiltinOperator_SQUEEZE, - BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, - BuiltinOperator_STRIDED_SLICE, - BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN}; + BuiltinOperator_ADD, + BuiltinOperator_AVERAGE_POOL_2D, + BuiltinOperator_CONCATENATION, + BuiltinOperator_CONV_2D, + BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOperator_EMBEDDING_LOOKUP, + BuiltinOperator_FULLY_CONNECTED, + BuiltinOperator_HASHTABLE_LOOKUP, + BuiltinOperator_L2_NORMALIZATION, + BuiltinOperator_L2_POOL_2D, + BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, + BuiltinOperator_LOGISTIC, + BuiltinOperator_LSH_PROJECTION, + BuiltinOperator_LSTM, + BuiltinOperator_MAX_POOL_2D, + BuiltinOperator_MUL, + BuiltinOperator_RELU, + BuiltinOperator_RELU_N1_TO_1, + BuiltinOperator_RELU6, + BuiltinOperator_RESHAPE, + BuiltinOperator_RESIZE_BILINEAR, + BuiltinOperator_RNN, + BuiltinOperator_SOFTMAX, + BuiltinOperator_SPACE_TO_DEPTH, + BuiltinOperator_SVDF, + BuiltinOperator_TANH, + BuiltinOperator_CONCAT_EMBEDDINGS, + BuiltinOperator_SKIP_GRAM, + BuiltinOperator_CALL, + BuiltinOperator_CUSTOM, + BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, + BuiltinOperator_PAD, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOperator_GATHER, + BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOperator_TRANSPOSE, + BuiltinOperator_MEAN, + BuiltinOperator_SUB, + BuiltinOperator_DIV, + BuiltinOperator_SQUEEZE, + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOperator_STRIDED_SLICE, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOperator_EXP, + BuiltinOperator_TOPK_V2, + BuiltinOperator_SPLIT, + BuiltinOperator_LOG_SOFTMAX, + BuiltinOperator_DELEGATE, + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM + }; return values; } inline const char **EnumNamesBuiltinOperator() { - static const char *names[] = {"ADD", - "AVERAGE_POOL_2D", - "CONCATENATION", - "CONV_2D", - "DEPTHWISE_CONV_2D", - "", - "", - "EMBEDDING_LOOKUP", - "", - "FULLY_CONNECTED", - "HASHTABLE_LOOKUP", - "L2_NORMALIZATION", - "L2_POOL_2D", - "LOCAL_RESPONSE_NORMALIZATION", - "LOGISTIC", - "LSH_PROJECTION", - "LSTM", - "MAX_POOL_2D", - "MUL", - "RELU", - "RELU_N1_TO_1", - "RELU6", - "RESHAPE", - "RESIZE_BILINEAR", - "RNN", - "SOFTMAX", - "SPACE_TO_DEPTH", - "SVDF", - "TANH", - "CONCAT_EMBEDDINGS", - "SKIP_GRAM", - "CALL", - "CUSTOM", - "EMBEDDING_LOOKUP_SPARSE", - "PAD", - "UNIDIRECTIONAL_SEQUENCE_RNN", - "GATHER", - "BATCH_TO_SPACE_ND", - "SPACE_TO_BATCH_ND", - "TRANSPOSE", - "MEAN", - "SUB", - "DIV", - "SQUEEZE", - "UNIDIRECTIONAL_SEQUENCE_LSTM", - "STRIDED_SLICE", - "BIDIRECTIONAL_SEQUENCE_RNN", - nullptr}; + static const char *names[] = { + "ADD", + "AVERAGE_POOL_2D", + "CONCATENATION", + "CONV_2D", + "DEPTHWISE_CONV_2D", + "", + "", + "EMBEDDING_LOOKUP", + "", + "FULLY_CONNECTED", + "HASHTABLE_LOOKUP", + "L2_NORMALIZATION", + "L2_POOL_2D", + "LOCAL_RESPONSE_NORMALIZATION", + "LOGISTIC", + "LSH_PROJECTION", + "LSTM", + "MAX_POOL_2D", + "MUL", + "RELU", + "RELU_N1_TO_1", + "RELU6", + "RESHAPE", + "RESIZE_BILINEAR", + "RNN", + "SOFTMAX", + "SPACE_TO_DEPTH", + "SVDF", + "TANH", + "CONCAT_EMBEDDINGS", + "SKIP_GRAM", + "CALL", + "CUSTOM", + "EMBEDDING_LOOKUP_SPARSE", + "PAD", + "UNIDIRECTIONAL_SEQUENCE_RNN", + "GATHER", + "BATCH_TO_SPACE_ND", + "SPACE_TO_BATCH_ND", + "TRANSPOSE", + "MEAN", + "SUB", + "DIV", + "SQUEEZE", + "UNIDIRECTIONAL_SEQUENCE_LSTM", + "STRIDED_SLICE", + "BIDIRECTIONAL_SEQUENCE_RNN", + "EXP", + "TOPK_V2", + "SPLIT", + "LOG_SOFTMAX", + "DELEGATE", + "BIDIRECTIONAL_SEQUENCE_LSTM", + nullptr + }; return names; } @@ -359,83 +405,98 @@ enum BuiltinOptions { BuiltinOptions_SqueezeOptions = 30, BuiltinOptions_SequenceRNNOptions = 31, BuiltinOptions_StridedSliceOptions = 32, + BuiltinOptions_ExpOptions = 33, + BuiltinOptions_TopKV2Options = 34, + BuiltinOptions_SplitOptions = 35, + BuiltinOptions_LogSoftmaxOptions = 36, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_StridedSliceOptions + BuiltinOptions_MAX = BuiltinOptions_LogSoftmaxOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[33] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[37] { static BuiltinOptions values[] = { - BuiltinOptions_NONE, - BuiltinOptions_Conv2DOptions, - BuiltinOptions_DepthwiseConv2DOptions, - BuiltinOptions_ConcatEmbeddingsOptions, - BuiltinOptions_LSHProjectionOptions, - BuiltinOptions_Pool2DOptions, - BuiltinOptions_SVDFOptions, - BuiltinOptions_RNNOptions, - BuiltinOptions_FullyConnectedOptions, - BuiltinOptions_SoftmaxOptions, - BuiltinOptions_ConcatenationOptions, - BuiltinOptions_AddOptions, - BuiltinOptions_L2NormOptions, - BuiltinOptions_LocalResponseNormalizationOptions, - BuiltinOptions_LSTMOptions, - BuiltinOptions_ResizeBilinearOptions, - BuiltinOptions_CallOptions, - BuiltinOptions_ReshapeOptions, - BuiltinOptions_SkipGramOptions, - BuiltinOptions_SpaceToDepthOptions, - BuiltinOptions_EmbeddingLookupSparseOptions, - BuiltinOptions_MulOptions, - BuiltinOptions_PadOptions, - BuiltinOptions_GatherOptions, - BuiltinOptions_BatchToSpaceNDOptions, - BuiltinOptions_SpaceToBatchNDOptions, - BuiltinOptions_TransposeOptions, - BuiltinOptions_MeanOptions, - BuiltinOptions_SubOptions, - BuiltinOptions_DivOptions, - BuiltinOptions_SqueezeOptions, - BuiltinOptions_SequenceRNNOptions, - BuiltinOptions_StridedSliceOptions}; + BuiltinOptions_NONE, + BuiltinOptions_Conv2DOptions, + BuiltinOptions_DepthwiseConv2DOptions, + BuiltinOptions_ConcatEmbeddingsOptions, + BuiltinOptions_LSHProjectionOptions, + BuiltinOptions_Pool2DOptions, + BuiltinOptions_SVDFOptions, + BuiltinOptions_RNNOptions, + BuiltinOptions_FullyConnectedOptions, + BuiltinOptions_SoftmaxOptions, + BuiltinOptions_ConcatenationOptions, + BuiltinOptions_AddOptions, + BuiltinOptions_L2NormOptions, + BuiltinOptions_LocalResponseNormalizationOptions, + BuiltinOptions_LSTMOptions, + BuiltinOptions_ResizeBilinearOptions, + BuiltinOptions_CallOptions, + BuiltinOptions_ReshapeOptions, + BuiltinOptions_SkipGramOptions, + BuiltinOptions_SpaceToDepthOptions, + BuiltinOptions_EmbeddingLookupSparseOptions, + BuiltinOptions_MulOptions, + BuiltinOptions_PadOptions, + BuiltinOptions_GatherOptions, + BuiltinOptions_BatchToSpaceNDOptions, + BuiltinOptions_SpaceToBatchNDOptions, + BuiltinOptions_TransposeOptions, + BuiltinOptions_MeanOptions, + BuiltinOptions_SubOptions, + BuiltinOptions_DivOptions, + BuiltinOptions_SqueezeOptions, + BuiltinOptions_SequenceRNNOptions, + BuiltinOptions_StridedSliceOptions, + BuiltinOptions_ExpOptions, + BuiltinOptions_TopKV2Options, + BuiltinOptions_SplitOptions, + BuiltinOptions_LogSoftmaxOptions + }; return values; } inline const char **EnumNamesBuiltinOptions() { - static const char *names[] = {"NONE", - "Conv2DOptions", - "DepthwiseConv2DOptions", - "ConcatEmbeddingsOptions", - "LSHProjectionOptions", - "Pool2DOptions", - "SVDFOptions", - "RNNOptions", - "FullyConnectedOptions", - "SoftmaxOptions", - "ConcatenationOptions", - "AddOptions", - "L2NormOptions", - "LocalResponseNormalizationOptions", - "LSTMOptions", - "ResizeBilinearOptions", - "CallOptions", - "ReshapeOptions", - "SkipGramOptions", - "SpaceToDepthOptions", - "EmbeddingLookupSparseOptions", - "MulOptions", - "PadOptions", - "GatherOptions", - "BatchToSpaceNDOptions", - "SpaceToBatchNDOptions", - "TransposeOptions", - "MeanOptions", - "SubOptions", - "DivOptions", - "SqueezeOptions", - "SequenceRNNOptions", - "StridedSliceOptions", - nullptr}; + static const char *names[] = { + "NONE", + "Conv2DOptions", + "DepthwiseConv2DOptions", + "ConcatEmbeddingsOptions", + "LSHProjectionOptions", + "Pool2DOptions", + "SVDFOptions", + "RNNOptions", + "FullyConnectedOptions", + "SoftmaxOptions", + "ConcatenationOptions", + "AddOptions", + "L2NormOptions", + "LocalResponseNormalizationOptions", + "LSTMOptions", + "ResizeBilinearOptions", + "CallOptions", + "ReshapeOptions", + "SkipGramOptions", + "SpaceToDepthOptions", + "EmbeddingLookupSparseOptions", + "MulOptions", + "PadOptions", + "GatherOptions", + "BatchToSpaceNDOptions", + "SpaceToBatchNDOptions", + "TransposeOptions", + "MeanOptions", + "SubOptions", + "DivOptions", + "SqueezeOptions", + "SequenceRNNOptions", + "StridedSliceOptions", + "ExpOptions", + "TopKV2Options", + "SplitOptions", + "LogSoftmaxOptions", + nullptr + }; return names; } @@ -444,206 +505,174 @@ inline const char *EnumNameBuiltinOptions(BuiltinOptions e) { return EnumNamesBuiltinOptions()[index]; } -template -struct BuiltinOptionsTraits { +template struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_NONE; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_Conv2DOptions; }; -template <> -struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = - BuiltinOptions_DepthwiseConv2DOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DepthwiseConv2DOptions; }; -template <> -struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = - BuiltinOptions_ConcatEmbeddingsOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ConcatEmbeddingsOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LSHProjectionOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_Pool2DOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SVDFOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_RNNOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_FullyConnectedOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SoftmaxOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ConcatenationOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_AddOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_L2NormOptions; }; -template <> -struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = - BuiltinOptions_LocalResponseNormalizationOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LocalResponseNormalizationOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_LSTMOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ResizeBilinearOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_CallOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_ReshapeOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SkipGramOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SpaceToDepthOptions; }; -template <> -struct BuiltinOptionsTraits { - static const BuiltinOptions enum_value = - BuiltinOptions_EmbeddingLookupSparseOptions; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_EmbeddingLookupSparseOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_MulOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_PadOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_GatherOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_BatchToSpaceNDOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SpaceToBatchNDOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_TransposeOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_MeanOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SubOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_DivOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SqueezeOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_SequenceRNNOptions; }; -template <> -struct BuiltinOptionsTraits { +template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_ExpOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_TopKV2Options; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_SplitOptions; +}; + +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_LogSoftmaxOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; BuiltinOptionsUnion() : type(BuiltinOptions_NONE), value(nullptr) {} - BuiltinOptionsUnion(BuiltinOptionsUnion &&u) FLATBUFFERS_NOEXCEPT - : type(BuiltinOptions_NONE), - value(nullptr) { - std::swap(type, u.type); - std::swap(value, u.value); - } + BuiltinOptionsUnion(BuiltinOptionsUnion&& u) FLATBUFFERS_NOEXCEPT : + type(BuiltinOptions_NONE), value(nullptr) + { std::swap(type, u.type); std::swap(value, u.value); } BuiltinOptionsUnion(const BuiltinOptionsUnion &) FLATBUFFERS_NOEXCEPT; - BuiltinOptionsUnion &operator=(const BuiltinOptionsUnion &u) - FLATBUFFERS_NOEXCEPT { - BuiltinOptionsUnion t(u); - std::swap(type, t.type); - std::swap(value, t.value); - return *this; - } - BuiltinOptionsUnion &operator=(BuiltinOptionsUnion &&u) FLATBUFFERS_NOEXCEPT { - std::swap(type, u.type); - std::swap(value, u.value); - return *this; - } + BuiltinOptionsUnion &operator=(const BuiltinOptionsUnion &u) FLATBUFFERS_NOEXCEPT + { BuiltinOptionsUnion t(u); std::swap(type, t.type); std::swap(value, t.value); return *this; } + BuiltinOptionsUnion &operator=(BuiltinOptionsUnion &&u) FLATBUFFERS_NOEXCEPT + { std::swap(type, u.type); std::swap(value, u.value); return *this; } ~BuiltinOptionsUnion() { Reset(); } void Reset(); #ifndef FLATBUFFERS_CPP98_STL template - void Set(T &&val) { + void Set(T&& val) { Reset(); type = BuiltinOptionsTraits::enum_value; if (type != BuiltinOptions_NONE) { @@ -652,342 +681,301 @@ struct BuiltinOptionsUnion { } #endif // FLATBUFFERS_CPP98_STL - static void *UnPack(const void *obj, BuiltinOptions type, - const flatbuffers::resolver_function_t *resolver); - flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const flatbuffers::rehasher_function_t *_rehasher = nullptr) const; + static void *UnPack(const void *obj, BuiltinOptions type, const flatbuffers::resolver_function_t *resolver); + flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher = nullptr) const; Conv2DOptionsT *AsConv2DOptions() { - return type == BuiltinOptions_Conv2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_Conv2DOptions ? + reinterpret_cast(value) : nullptr; } const Conv2DOptionsT *AsConv2DOptions() const { - return type == BuiltinOptions_Conv2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_Conv2DOptions ? + reinterpret_cast(value) : nullptr; } DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() { - return type == BuiltinOptions_DepthwiseConv2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_DepthwiseConv2DOptions ? + reinterpret_cast(value) : nullptr; } const DepthwiseConv2DOptionsT *AsDepthwiseConv2DOptions() const { - return type == BuiltinOptions_DepthwiseConv2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_DepthwiseConv2DOptions ? + reinterpret_cast(value) : nullptr; } ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() { - return type == BuiltinOptions_ConcatEmbeddingsOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ConcatEmbeddingsOptions ? + reinterpret_cast(value) : nullptr; } const ConcatEmbeddingsOptionsT *AsConcatEmbeddingsOptions() const { - return type == BuiltinOptions_ConcatEmbeddingsOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ConcatEmbeddingsOptions ? + reinterpret_cast(value) : nullptr; } LSHProjectionOptionsT *AsLSHProjectionOptions() { - return type == BuiltinOptions_LSHProjectionOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_LSHProjectionOptions ? + reinterpret_cast(value) : nullptr; } const LSHProjectionOptionsT *AsLSHProjectionOptions() const { - return type == BuiltinOptions_LSHProjectionOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_LSHProjectionOptions ? + reinterpret_cast(value) : nullptr; } Pool2DOptionsT *AsPool2DOptions() { - return type == BuiltinOptions_Pool2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_Pool2DOptions ? + reinterpret_cast(value) : nullptr; } const Pool2DOptionsT *AsPool2DOptions() const { - return type == BuiltinOptions_Pool2DOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_Pool2DOptions ? + reinterpret_cast(value) : nullptr; } SVDFOptionsT *AsSVDFOptions() { - return type == BuiltinOptions_SVDFOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SVDFOptions ? + reinterpret_cast(value) : nullptr; } const SVDFOptionsT *AsSVDFOptions() const { - return type == BuiltinOptions_SVDFOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SVDFOptions ? + reinterpret_cast(value) : nullptr; } RNNOptionsT *AsRNNOptions() { - return type == BuiltinOptions_RNNOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_RNNOptions ? + reinterpret_cast(value) : nullptr; } const RNNOptionsT *AsRNNOptions() const { - return type == BuiltinOptions_RNNOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_RNNOptions ? + reinterpret_cast(value) : nullptr; } FullyConnectedOptionsT *AsFullyConnectedOptions() { - return type == BuiltinOptions_FullyConnectedOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_FullyConnectedOptions ? + reinterpret_cast(value) : nullptr; } const FullyConnectedOptionsT *AsFullyConnectedOptions() const { - return type == BuiltinOptions_FullyConnectedOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_FullyConnectedOptions ? + reinterpret_cast(value) : nullptr; } SoftmaxOptionsT *AsSoftmaxOptions() { - return type == BuiltinOptions_SoftmaxOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SoftmaxOptions ? + reinterpret_cast(value) : nullptr; } const SoftmaxOptionsT *AsSoftmaxOptions() const { - return type == BuiltinOptions_SoftmaxOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SoftmaxOptions ? + reinterpret_cast(value) : nullptr; } ConcatenationOptionsT *AsConcatenationOptions() { - return type == BuiltinOptions_ConcatenationOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ConcatenationOptions ? + reinterpret_cast(value) : nullptr; } const ConcatenationOptionsT *AsConcatenationOptions() const { - return type == BuiltinOptions_ConcatenationOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ConcatenationOptions ? + reinterpret_cast(value) : nullptr; } AddOptionsT *AsAddOptions() { - return type == BuiltinOptions_AddOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_AddOptions ? + reinterpret_cast(value) : nullptr; } const AddOptionsT *AsAddOptions() const { - return type == BuiltinOptions_AddOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_AddOptions ? + reinterpret_cast(value) : nullptr; } L2NormOptionsT *AsL2NormOptions() { - return type == BuiltinOptions_L2NormOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_L2NormOptions ? + reinterpret_cast(value) : nullptr; } const L2NormOptionsT *AsL2NormOptions() const { - return type == BuiltinOptions_L2NormOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_L2NormOptions ? + reinterpret_cast(value) : nullptr; } LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() { - return type == BuiltinOptions_LocalResponseNormalizationOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_LocalResponseNormalizationOptions ? + reinterpret_cast(value) : nullptr; } - const LocalResponseNormalizationOptionsT * - AsLocalResponseNormalizationOptions() const { - return type == BuiltinOptions_LocalResponseNormalizationOptions - ? reinterpret_cast( - value) - : nullptr; + const LocalResponseNormalizationOptionsT *AsLocalResponseNormalizationOptions() const { + return type == BuiltinOptions_LocalResponseNormalizationOptions ? + reinterpret_cast(value) : nullptr; } LSTMOptionsT *AsLSTMOptions() { - return type == BuiltinOptions_LSTMOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_LSTMOptions ? + reinterpret_cast(value) : nullptr; } const LSTMOptionsT *AsLSTMOptions() const { - return type == BuiltinOptions_LSTMOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_LSTMOptions ? + reinterpret_cast(value) : nullptr; } ResizeBilinearOptionsT *AsResizeBilinearOptions() { - return type == BuiltinOptions_ResizeBilinearOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ResizeBilinearOptions ? + reinterpret_cast(value) : nullptr; } const ResizeBilinearOptionsT *AsResizeBilinearOptions() const { - return type == BuiltinOptions_ResizeBilinearOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ResizeBilinearOptions ? + reinterpret_cast(value) : nullptr; } CallOptionsT *AsCallOptions() { - return type == BuiltinOptions_CallOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_CallOptions ? + reinterpret_cast(value) : nullptr; } const CallOptionsT *AsCallOptions() const { - return type == BuiltinOptions_CallOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_CallOptions ? + reinterpret_cast(value) : nullptr; } ReshapeOptionsT *AsReshapeOptions() { - return type == BuiltinOptions_ReshapeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ReshapeOptions ? + reinterpret_cast(value) : nullptr; } const ReshapeOptionsT *AsReshapeOptions() const { - return type == BuiltinOptions_ReshapeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_ReshapeOptions ? + reinterpret_cast(value) : nullptr; } SkipGramOptionsT *AsSkipGramOptions() { - return type == BuiltinOptions_SkipGramOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SkipGramOptions ? + reinterpret_cast(value) : nullptr; } const SkipGramOptionsT *AsSkipGramOptions() const { - return type == BuiltinOptions_SkipGramOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SkipGramOptions ? + reinterpret_cast(value) : nullptr; } SpaceToDepthOptionsT *AsSpaceToDepthOptions() { - return type == BuiltinOptions_SpaceToDepthOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SpaceToDepthOptions ? + reinterpret_cast(value) : nullptr; } const SpaceToDepthOptionsT *AsSpaceToDepthOptions() const { - return type == BuiltinOptions_SpaceToDepthOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SpaceToDepthOptions ? + reinterpret_cast(value) : nullptr; } EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() { - return type == BuiltinOptions_EmbeddingLookupSparseOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_EmbeddingLookupSparseOptions ? + reinterpret_cast(value) : nullptr; } const EmbeddingLookupSparseOptionsT *AsEmbeddingLookupSparseOptions() const { - return type == BuiltinOptions_EmbeddingLookupSparseOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_EmbeddingLookupSparseOptions ? + reinterpret_cast(value) : nullptr; } MulOptionsT *AsMulOptions() { - return type == BuiltinOptions_MulOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_MulOptions ? + reinterpret_cast(value) : nullptr; } const MulOptionsT *AsMulOptions() const { - return type == BuiltinOptions_MulOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_MulOptions ? + reinterpret_cast(value) : nullptr; } PadOptionsT *AsPadOptions() { - return type == BuiltinOptions_PadOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_PadOptions ? + reinterpret_cast(value) : nullptr; } const PadOptionsT *AsPadOptions() const { - return type == BuiltinOptions_PadOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_PadOptions ? + reinterpret_cast(value) : nullptr; } GatherOptionsT *AsGatherOptions() { - return type == BuiltinOptions_GatherOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_GatherOptions ? + reinterpret_cast(value) : nullptr; } const GatherOptionsT *AsGatherOptions() const { - return type == BuiltinOptions_GatherOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_GatherOptions ? + reinterpret_cast(value) : nullptr; } BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() { - return type == BuiltinOptions_BatchToSpaceNDOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_BatchToSpaceNDOptions ? + reinterpret_cast(value) : nullptr; } const BatchToSpaceNDOptionsT *AsBatchToSpaceNDOptions() const { - return type == BuiltinOptions_BatchToSpaceNDOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_BatchToSpaceNDOptions ? + reinterpret_cast(value) : nullptr; } SpaceToBatchNDOptionsT *AsSpaceToBatchNDOptions() { - return type == BuiltinOptions_SpaceToBatchNDOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SpaceToBatchNDOptions ? + reinterpret_cast(value) : nullptr; } const SpaceToBatchNDOptionsT *AsSpaceToBatchNDOptions() const { - return type == BuiltinOptions_SpaceToBatchNDOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SpaceToBatchNDOptions ? + reinterpret_cast(value) : nullptr; } TransposeOptionsT *AsTransposeOptions() { - return type == BuiltinOptions_TransposeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_TransposeOptions ? + reinterpret_cast(value) : nullptr; } const TransposeOptionsT *AsTransposeOptions() const { - return type == BuiltinOptions_TransposeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_TransposeOptions ? + reinterpret_cast(value) : nullptr; } MeanOptionsT *AsMeanOptions() { - return type == BuiltinOptions_MeanOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_MeanOptions ? + reinterpret_cast(value) : nullptr; } const MeanOptionsT *AsMeanOptions() const { - return type == BuiltinOptions_MeanOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_MeanOptions ? + reinterpret_cast(value) : nullptr; } SubOptionsT *AsSubOptions() { - return type == BuiltinOptions_SubOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SubOptions ? + reinterpret_cast(value) : nullptr; } const SubOptionsT *AsSubOptions() const { - return type == BuiltinOptions_SubOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SubOptions ? + reinterpret_cast(value) : nullptr; } DivOptionsT *AsDivOptions() { - return type == BuiltinOptions_DivOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_DivOptions ? + reinterpret_cast(value) : nullptr; } const DivOptionsT *AsDivOptions() const { - return type == BuiltinOptions_DivOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_DivOptions ? + reinterpret_cast(value) : nullptr; } SqueezeOptionsT *AsSqueezeOptions() { - return type == BuiltinOptions_SqueezeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SqueezeOptions ? + reinterpret_cast(value) : nullptr; } const SqueezeOptionsT *AsSqueezeOptions() const { - return type == BuiltinOptions_SqueezeOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SqueezeOptions ? + reinterpret_cast(value) : nullptr; } SequenceRNNOptionsT *AsSequenceRNNOptions() { - return type == BuiltinOptions_SequenceRNNOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SequenceRNNOptions ? + reinterpret_cast(value) : nullptr; } const SequenceRNNOptionsT *AsSequenceRNNOptions() const { - return type == BuiltinOptions_SequenceRNNOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_SequenceRNNOptions ? + reinterpret_cast(value) : nullptr; } StridedSliceOptionsT *AsStridedSliceOptions() { - return type == BuiltinOptions_StridedSliceOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_StridedSliceOptions ? + reinterpret_cast(value) : nullptr; } const StridedSliceOptionsT *AsStridedSliceOptions() const { - return type == BuiltinOptions_StridedSliceOptions - ? reinterpret_cast(value) - : nullptr; + return type == BuiltinOptions_StridedSliceOptions ? + reinterpret_cast(value) : nullptr; + } + ExpOptionsT *AsExpOptions() { + return type == BuiltinOptions_ExpOptions ? + reinterpret_cast(value) : nullptr; + } + const ExpOptionsT *AsExpOptions() const { + return type == BuiltinOptions_ExpOptions ? + reinterpret_cast(value) : nullptr; + } + TopKV2OptionsT *AsTopKV2Options() { + return type == BuiltinOptions_TopKV2Options ? + reinterpret_cast(value) : nullptr; + } + const TopKV2OptionsT *AsTopKV2Options() const { + return type == BuiltinOptions_TopKV2Options ? + reinterpret_cast(value) : nullptr; + } + SplitOptionsT *AsSplitOptions() { + return type == BuiltinOptions_SplitOptions ? + reinterpret_cast(value) : nullptr; + } + const SplitOptionsT *AsSplitOptions() const { + return type == BuiltinOptions_SplitOptions ? + reinterpret_cast(value) : nullptr; + } + LogSoftmaxOptionsT *AsLogSoftmaxOptions() { + return type == BuiltinOptions_LogSoftmaxOptions ? + reinterpret_cast(value) : nullptr; + } + const LogSoftmaxOptionsT *AsLogSoftmaxOptions() const { + return type == BuiltinOptions_LogSoftmaxOptions ? + reinterpret_cast(value) : nullptr; } }; -bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, - BuiltinOptions type); -bool VerifyBuiltinOptionsVector( - flatbuffers::Verifier &verifier, - const flatbuffers::Vector> *values, - const flatbuffers::Vector *types); +bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); +bool VerifyBuiltinOptionsVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); enum Padding { Padding_SAME = 0, @@ -997,12 +985,19 @@ enum Padding { }; inline Padding (&EnumValuesPadding())[2] { - static Padding values[] = {Padding_SAME, Padding_VALID}; + static Padding values[] = { + Padding_SAME, + Padding_VALID + }; return values; } inline const char **EnumNamesPadding() { - static const char *names[] = {"SAME", "VALID", nullptr}; + static const char *names[] = { + "SAME", + "VALID", + nullptr + }; return names; } @@ -1024,15 +1019,26 @@ enum ActivationFunctionType { inline ActivationFunctionType (&EnumValuesActivationFunctionType())[6] { static ActivationFunctionType values[] = { - ActivationFunctionType_NONE, ActivationFunctionType_RELU, - ActivationFunctionType_RELU_N1_TO_1, ActivationFunctionType_RELU6, - ActivationFunctionType_TANH, ActivationFunctionType_SIGN_BIT}; + ActivationFunctionType_NONE, + ActivationFunctionType_RELU, + ActivationFunctionType_RELU_N1_TO_1, + ActivationFunctionType_RELU6, + ActivationFunctionType_TANH, + ActivationFunctionType_SIGN_BIT + }; return values; } inline const char **EnumNamesActivationFunctionType() { - static const char *names[] = {"NONE", "RELU", "RELU_N1_TO_1", "RELU6", - "TANH", "SIGN_BIT", nullptr}; + static const char *names[] = { + "NONE", + "RELU", + "RELU_N1_TO_1", + "RELU6", + "TANH", + "SIGN_BIT", + nullptr + }; return names; } @@ -1050,14 +1056,21 @@ enum LSHProjectionType { }; inline LSHProjectionType (&EnumValuesLSHProjectionType())[3] { - static LSHProjectionType values[] = {LSHProjectionType_UNKNOWN, - LSHProjectionType_SPARSE, - LSHProjectionType_DENSE}; + static LSHProjectionType values[] = { + LSHProjectionType_UNKNOWN, + LSHProjectionType_SPARSE, + LSHProjectionType_DENSE + }; return values; } inline const char **EnumNamesLSHProjectionType() { - static const char *names[] = {"UNKNOWN", "SPARSE", "DENSE", nullptr}; + static const char *names[] = { + "UNKNOWN", + "SPARSE", + "DENSE", + nullptr + }; return names; } @@ -1075,13 +1088,21 @@ enum CombinerType { }; inline CombinerType (&EnumValuesCombinerType())[3] { - static CombinerType values[] = {CombinerType_SUM, CombinerType_MEAN, - CombinerType_SQRTN}; + static CombinerType values[] = { + CombinerType_SUM, + CombinerType_MEAN, + CombinerType_SQRTN + }; return values; } inline const char **EnumNamesCombinerType() { - static const char *names[] = {"SUM", "MEAN", "SQRTN", nullptr}; + static const char *names[] = { + "SUM", + "MEAN", + "SQRTN", + nullptr + }; return names; } @@ -1097,12 +1118,17 @@ enum CustomOptionsFormat { }; inline CustomOptionsFormat (&EnumValuesCustomOptionsFormat())[1] { - static CustomOptionsFormat values[] = {CustomOptionsFormat_FLEXBUFFERS}; + static CustomOptionsFormat values[] = { + CustomOptionsFormat_FLEXBUFFERS + }; return values; } inline const char **EnumNamesCustomOptionsFormat() { - static const char *names[] = {"FLEXBUFFERS", nullptr}; + static const char *names[] = { + "FLEXBUFFERS", + nullptr + }; return names; } @@ -1117,13 +1143,18 @@ struct QuantizationParametersT : public flatbuffers::NativeTable { std::vector max; std::vector scale; std::vector zero_point; - QuantizationParametersT() {} + QuantizationParametersT() { + } }; -struct QuantizationParameters FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct QuantizationParameters FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef QuantizationParametersT NativeTableType; - enum { VT_MIN = 4, VT_MAX = 6, VT_SCALE = 8, VT_ZERO_POINT = 10 }; + enum { + VT_MIN = 4, + VT_MAX = 6, + VT_SCALE = 8, + VT_ZERO_POINT = 10 + }; const flatbuffers::Vector *min() const { return GetPointer *>(VT_MIN); } @@ -1137,20 +1168,20 @@ struct QuantizationParameters FLATBUFFERS_FINAL_CLASS return GetPointer *>(VT_ZERO_POINT); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_MIN) && - verifier.Verify(min()) && VerifyOffset(verifier, VT_MAX) && - verifier.Verify(max()) && VerifyOffset(verifier, VT_SCALE) && - verifier.Verify(scale()) && VerifyOffset(verifier, VT_ZERO_POINT) && - verifier.Verify(zero_point()) && verifier.EndTable(); - } - QuantizationParametersT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - QuantizationParametersT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_MIN) && + verifier.Verify(min()) && + VerifyOffset(verifier, VT_MAX) && + verifier.Verify(max()) && + VerifyOffset(verifier, VT_SCALE) && + verifier.Verify(scale()) && + VerifyOffset(verifier, VT_ZERO_POINT) && + verifier.Verify(zero_point()) && + verifier.EndTable(); + } + QuantizationParametersT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(QuantizationParametersT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct QuantizationParametersBuilder { @@ -1165,16 +1196,14 @@ struct QuantizationParametersBuilder { void add_scale(flatbuffers::Offset> scale) { fbb_.AddOffset(QuantizationParameters::VT_SCALE, scale); } - void add_zero_point( - flatbuffers::Offset> zero_point) { + void add_zero_point(flatbuffers::Offset> zero_point) { fbb_.AddOffset(QuantizationParameters::VT_ZERO_POINT, zero_point); } explicit QuantizationParametersBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - QuantizationParametersBuilder &operator=( - const QuantizationParametersBuilder &); + QuantizationParametersBuilder &operator=(const QuantizationParametersBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1196,23 +1225,21 @@ inline flatbuffers::Offset CreateQuantizationParameters( return builder_.Finish(); } -inline flatbuffers::Offset -CreateQuantizationParametersDirect( +inline flatbuffers::Offset CreateQuantizationParametersDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *min = nullptr, const std::vector *max = nullptr, const std::vector *scale = nullptr, const std::vector *zero_point = nullptr) { return tflite::CreateQuantizationParameters( - _fbb, min ? _fbb.CreateVector(*min) : 0, + _fbb, + min ? _fbb.CreateVector(*min) : 0, max ? _fbb.CreateVector(*max) : 0, scale ? _fbb.CreateVector(*scale) : 0, zero_point ? _fbb.CreateVector(*zero_point) : 0); } -flatbuffers::Offset CreateQuantizationParameters( - flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateQuantizationParameters(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct TensorT : public flatbuffers::NativeTable { typedef Tensor TableType; @@ -1221,7 +1248,10 @@ struct TensorT : public flatbuffers::NativeTable { uint32_t buffer; std::string name; std::unique_ptr quantization; - TensorT() : type(TensorType_FLOAT32), buffer(0) {} + TensorT() + : type(TensorType_FLOAT32), + buffer(0) { + } }; struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -1239,7 +1269,9 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { TensorType type() const { return static_cast(GetField(VT_TYPE, 0)); } - uint32_t buffer() const { return GetField(VT_BUFFER, 0); } + uint32_t buffer() const { + return GetField(VT_BUFFER, 0); + } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } @@ -1247,20 +1279,20 @@ struct Tensor FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer(VT_QUANTIZATION); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SHAPE) && - verifier.Verify(shape()) && VerifyField(verifier, VT_TYPE) && + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_SHAPE) && + verifier.Verify(shape()) && + VerifyField(verifier, VT_TYPE) && VerifyField(verifier, VT_BUFFER) && - VerifyOffset(verifier, VT_NAME) && verifier.Verify(name()) && + VerifyOffset(verifier, VT_NAME) && + verifier.Verify(name()) && VerifyOffset(verifier, VT_QUANTIZATION) && - verifier.VerifyTable(quantization()) && verifier.EndTable(); + verifier.VerifyTable(quantization()) && + verifier.EndTable(); } - TensorT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t *_resolver = - nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + TensorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct TensorBuilder { @@ -1278,11 +1310,11 @@ struct TensorBuilder { void add_name(flatbuffers::Offset name) { fbb_.AddOffset(Tensor::VT_NAME, name); } - void add_quantization( - flatbuffers::Offset quantization) { + void add_quantization(flatbuffers::Offset quantization) { fbb_.AddOffset(Tensor::VT_QUANTIZATION, quantization); } - explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + explicit TensorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } TensorBuilder &operator=(const TensorBuilder &); @@ -1296,7 +1328,8 @@ struct TensorBuilder { inline flatbuffers::Offset CreateTensor( flatbuffers::FlatBufferBuilder &_fbb, flatbuffers::Offset> shape = 0, - TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, + TensorType type = TensorType_FLOAT32, + uint32_t buffer = 0, flatbuffers::Offset name = 0, flatbuffers::Offset quantization = 0) { TensorBuilder builder_(_fbb); @@ -1311,17 +1344,20 @@ inline flatbuffers::Offset CreateTensor( inline flatbuffers::Offset CreateTensorDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *shape = nullptr, - TensorType type = TensorType_FLOAT32, uint32_t buffer = 0, + TensorType type = TensorType_FLOAT32, + uint32_t buffer = 0, const char *name = nullptr, flatbuffers::Offset quantization = 0) { return tflite::CreateTensor( - _fbb, shape ? _fbb.CreateVector(*shape) : 0, type, buffer, - name ? _fbb.CreateString(name) : 0, quantization); + _fbb, + shape ? _fbb.CreateVector(*shape) : 0, + type, + buffer, + name ? _fbb.CreateString(name) : 0, + quantization); } -flatbuffers::Offset CreateTensor( - flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct Conv2DOptionsT : public flatbuffers::NativeTable { typedef Conv2DOptions TableType; @@ -1333,7 +1369,8 @@ struct Conv2DOptionsT : public flatbuffers::NativeTable { : padding(Padding_SAME), stride_w(0), stride_h(0), - fused_activation_function(ActivationFunctionType_NONE) {} + fused_activation_function(ActivationFunctionType_NONE) { + } }; struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -1347,11 +1384,14 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { Padding padding() const { return static_cast(GetField(VT_PADDING, 0)); } - int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); } - int32_t stride_h() const { return GetField(VT_STRIDE_H, 0); } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1361,22 +1401,16 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - Conv2DOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - Conv2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + Conv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Conv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct Conv2DOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_padding(Padding padding) { - fbb_.AddElement(Conv2DOptions::VT_PADDING, - static_cast(padding), 0); + fbb_.AddElement(Conv2DOptions::VT_PADDING, static_cast(padding), 0); } void add_stride_w(int32_t stride_w) { fbb_.AddElement(Conv2DOptions::VT_STRIDE_W, stride_w, 0); @@ -1384,13 +1418,11 @@ struct Conv2DOptionsBuilder { void add_stride_h(int32_t stride_h) { fbb_.AddElement(Conv2DOptions::VT_STRIDE_H, stride_h, 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit Conv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } Conv2DOptionsBuilder &operator=(const Conv2DOptionsBuilder &); @@ -1402,10 +1434,11 @@ struct Conv2DOptionsBuilder { }; inline flatbuffers::Offset CreateConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, Padding padding = Padding_SAME, - int32_t stride_w = 0, int32_t stride_h = 0, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + flatbuffers::FlatBufferBuilder &_fbb, + Padding padding = Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { Conv2DOptionsBuilder builder_(_fbb); builder_.add_stride_h(stride_h); builder_.add_stride_w(stride_w); @@ -1414,9 +1447,7 @@ inline flatbuffers::Offset CreateConv2DOptions( return builder_.Finish(); } -flatbuffers::Offset CreateConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct Pool2DOptionsT : public flatbuffers::NativeTable { typedef Pool2DOptions TableType; @@ -1432,7 +1463,8 @@ struct Pool2DOptionsT : public flatbuffers::NativeTable { stride_h(0), filter_width(0), filter_height(0), - fused_activation_function(ActivationFunctionType_NONE) {} + fused_activation_function(ActivationFunctionType_NONE) { + } }; struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -1448,15 +1480,20 @@ struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { Padding padding() const { return static_cast(GetField(VT_PADDING, 0)); } - int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); } - int32_t stride_h() const { return GetField(VT_STRIDE_H, 0); } - int32_t filter_width() const { return GetField(VT_FILTER_WIDTH, 0); } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } + int32_t filter_width() const { + return GetField(VT_FILTER_WIDTH, 0); + } int32_t filter_height() const { return GetField(VT_FILTER_HEIGHT, 0); } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1468,22 +1505,16 @@ struct Pool2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - Pool2DOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - Pool2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + Pool2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(Pool2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct Pool2DOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_padding(Padding padding) { - fbb_.AddElement(Pool2DOptions::VT_PADDING, - static_cast(padding), 0); + fbb_.AddElement(Pool2DOptions::VT_PADDING, static_cast(padding), 0); } void add_stride_w(int32_t stride_w) { fbb_.AddElement(Pool2DOptions::VT_STRIDE_W, stride_w, 0); @@ -1497,13 +1528,11 @@ struct Pool2DOptionsBuilder { void add_filter_height(int32_t filter_height) { fbb_.AddElement(Pool2DOptions::VT_FILTER_HEIGHT, filter_height, 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(Pool2DOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(Pool2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit Pool2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } Pool2DOptionsBuilder &operator=(const Pool2DOptionsBuilder &); @@ -1515,11 +1544,13 @@ struct Pool2DOptionsBuilder { }; inline flatbuffers::Offset CreatePool2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, Padding padding = Padding_SAME, - int32_t stride_w = 0, int32_t stride_h = 0, int32_t filter_width = 0, + flatbuffers::FlatBufferBuilder &_fbb, + Padding padding = Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + int32_t filter_width = 0, int32_t filter_height = 0, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { Pool2DOptionsBuilder builder_(_fbb); builder_.add_filter_height(filter_height); builder_.add_filter_width(filter_width); @@ -1530,9 +1561,7 @@ inline flatbuffers::Offset CreatePool2DOptions( return builder_.Finish(); } -flatbuffers::Offset CreatePool2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreatePool2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable { typedef DepthwiseConv2DOptions TableType; @@ -1546,11 +1575,11 @@ struct DepthwiseConv2DOptionsT : public flatbuffers::NativeTable { stride_w(0), stride_h(0), depth_multiplier(0), - fused_activation_function(ActivationFunctionType_NONE) {} + fused_activation_function(ActivationFunctionType_NONE) { + } }; -struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef DepthwiseConv2DOptionsT NativeTableType; enum { VT_PADDING = 4, @@ -1562,14 +1591,17 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS Padding padding() const { return static_cast(GetField(VT_PADDING, 0)); } - int32_t stride_w() const { return GetField(VT_STRIDE_W, 0); } - int32_t stride_h() const { return GetField(VT_STRIDE_H, 0); } + int32_t stride_w() const { + return GetField(VT_STRIDE_W, 0); + } + int32_t stride_h() const { + return GetField(VT_STRIDE_H, 0); + } int32_t depth_multiplier() const { return GetField(VT_DEPTH_MULTIPLIER, 0); } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1580,22 +1612,16 @@ struct DepthwiseConv2DOptions FLATBUFFERS_FINAL_CLASS VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - DepthwiseConv2DOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - DepthwiseConv2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + DepthwiseConv2DOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DepthwiseConv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct DepthwiseConv2DOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_padding(Padding padding) { - fbb_.AddElement(DepthwiseConv2DOptions::VT_PADDING, - static_cast(padding), 0); + fbb_.AddElement(DepthwiseConv2DOptions::VT_PADDING, static_cast(padding), 0); } void add_stride_w(int32_t stride_w) { fbb_.AddElement(DepthwiseConv2DOptions::VT_STRIDE_W, stride_w, 0); @@ -1604,21 +1630,16 @@ struct DepthwiseConv2DOptionsBuilder { fbb_.AddElement(DepthwiseConv2DOptions::VT_STRIDE_H, stride_h, 0); } void add_depth_multiplier(int32_t depth_multiplier) { - fbb_.AddElement(DepthwiseConv2DOptions::VT_DEPTH_MULTIPLIER, - depth_multiplier, 0); + fbb_.AddElement(DepthwiseConv2DOptions::VT_DEPTH_MULTIPLIER, depth_multiplier, 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement( - DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(DepthwiseConv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit DepthwiseConv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - DepthwiseConv2DOptionsBuilder &operator=( - const DepthwiseConv2DOptionsBuilder &); + DepthwiseConv2DOptionsBuilder &operator=(const DepthwiseConv2DOptionsBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1627,10 +1648,12 @@ struct DepthwiseConv2DOptionsBuilder { }; inline flatbuffers::Offset CreateDepthwiseConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, Padding padding = Padding_SAME, - int32_t stride_w = 0, int32_t stride_h = 0, int32_t depth_multiplier = 0, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + flatbuffers::FlatBufferBuilder &_fbb, + Padding padding = Padding_SAME, + int32_t stride_w = 0, + int32_t stride_h = 0, + int32_t depth_multiplier = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { DepthwiseConv2DOptionsBuilder builder_(_fbb); builder_.add_depth_multiplier(depth_multiplier); builder_.add_stride_h(stride_h); @@ -1640,34 +1663,33 @@ inline flatbuffers::Offset CreateDepthwiseConv2DOptions( return builder_.Finish(); } -flatbuffers::Offset CreateDepthwiseConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateDepthwiseConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct ConcatEmbeddingsOptionsT : public flatbuffers::NativeTable { typedef ConcatEmbeddingsOptions TableType; int32_t num_channels; std::vector num_columns_per_channel; std::vector embedding_dim_per_channel; - ConcatEmbeddingsOptionsT() : num_channels(0) {} + ConcatEmbeddingsOptionsT() + : num_channels(0) { + } }; -struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ConcatEmbeddingsOptionsT NativeTableType; enum { VT_NUM_CHANNELS = 4, VT_NUM_COLUMNS_PER_CHANNEL = 6, VT_EMBEDDING_DIM_PER_CHANNEL = 8 }; - int32_t num_channels() const { return GetField(VT_NUM_CHANNELS, 0); } + int32_t num_channels() const { + return GetField(VT_NUM_CHANNELS, 0); + } const flatbuffers::Vector *num_columns_per_channel() const { - return GetPointer *>( - VT_NUM_COLUMNS_PER_CHANNEL); + return GetPointer *>(VT_NUM_COLUMNS_PER_CHANNEL); } const flatbuffers::Vector *embedding_dim_per_channel() const { - return GetPointer *>( - VT_EMBEDDING_DIM_PER_CHANNEL); + return GetPointer *>(VT_EMBEDDING_DIM_PER_CHANNEL); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1675,43 +1697,31 @@ struct ConcatEmbeddingsOptions FLATBUFFERS_FINAL_CLASS VerifyOffset(verifier, VT_NUM_COLUMNS_PER_CHANNEL) && verifier.Verify(num_columns_per_channel()) && VerifyOffset(verifier, VT_EMBEDDING_DIM_PER_CHANNEL) && - verifier.Verify(embedding_dim_per_channel()) && verifier.EndTable(); + verifier.Verify(embedding_dim_per_channel()) && + verifier.EndTable(); } - ConcatEmbeddingsOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - ConcatEmbeddingsOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ConcatEmbeddingsOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConcatEmbeddingsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ConcatEmbeddingsOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_num_channels(int32_t num_channels) { - fbb_.AddElement(ConcatEmbeddingsOptions::VT_NUM_CHANNELS, - num_channels, 0); + fbb_.AddElement(ConcatEmbeddingsOptions::VT_NUM_CHANNELS, num_channels, 0); } - void add_num_columns_per_channel( - flatbuffers::Offset> - num_columns_per_channel) { - fbb_.AddOffset(ConcatEmbeddingsOptions::VT_NUM_COLUMNS_PER_CHANNEL, - num_columns_per_channel); + void add_num_columns_per_channel(flatbuffers::Offset> num_columns_per_channel) { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_NUM_COLUMNS_PER_CHANNEL, num_columns_per_channel); } - void add_embedding_dim_per_channel( - flatbuffers::Offset> - embedding_dim_per_channel) { - fbb_.AddOffset(ConcatEmbeddingsOptions::VT_EMBEDDING_DIM_PER_CHANNEL, - embedding_dim_per_channel); + void add_embedding_dim_per_channel(flatbuffers::Offset> embedding_dim_per_channel) { + fbb_.AddOffset(ConcatEmbeddingsOptions::VT_EMBEDDING_DIM_PER_CHANNEL, embedding_dim_per_channel); } explicit ConcatEmbeddingsOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - ConcatEmbeddingsOptionsBuilder &operator=( - const ConcatEmbeddingsOptionsBuilder &); + ConcatEmbeddingsOptionsBuilder &operator=(const ConcatEmbeddingsOptionsBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -1719,13 +1729,11 @@ struct ConcatEmbeddingsOptionsBuilder { } }; -inline flatbuffers::Offset -CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, - int32_t num_channels = 0, - flatbuffers::Offset> - num_columns_per_channel = 0, - flatbuffers::Offset> - embedding_dim_per_channel = 0) { +inline flatbuffers::Offset CreateConcatEmbeddingsOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_channels = 0, + flatbuffers::Offset> num_columns_per_channel = 0, + flatbuffers::Offset> embedding_dim_per_channel = 0) { ConcatEmbeddingsOptionsBuilder builder_(_fbb); builder_.add_embedding_dim_per_channel(embedding_dim_per_channel); builder_.add_num_columns_per_channel(num_columns_per_channel); @@ -1733,61 +1741,54 @@ CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, return builder_.Finish(); } -inline flatbuffers::Offset -CreateConcatEmbeddingsOptionsDirect( - flatbuffers::FlatBufferBuilder &_fbb, int32_t num_channels = 0, +inline flatbuffers::Offset CreateConcatEmbeddingsOptionsDirect( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_channels = 0, const std::vector *num_columns_per_channel = nullptr, const std::vector *embedding_dim_per_channel = nullptr) { return tflite::CreateConcatEmbeddingsOptions( - _fbb, num_channels, - num_columns_per_channel - ? _fbb.CreateVector(*num_columns_per_channel) - : 0, - embedding_dim_per_channel - ? _fbb.CreateVector(*embedding_dim_per_channel) - : 0); + _fbb, + num_channels, + num_columns_per_channel ? _fbb.CreateVector(*num_columns_per_channel) : 0, + embedding_dim_per_channel ? _fbb.CreateVector(*embedding_dim_per_channel) : 0); } -flatbuffers::Offset CreateConcatEmbeddingsOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct LSHProjectionOptionsT : public flatbuffers::NativeTable { typedef LSHProjectionOptions TableType; LSHProjectionType type; - LSHProjectionOptionsT() : type(LSHProjectionType_UNKNOWN) {} + LSHProjectionOptionsT() + : type(LSHProjectionType_UNKNOWN) { + } }; -struct LSHProjectionOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct LSHProjectionOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef LSHProjectionOptionsT NativeTableType; - enum { VT_TYPE = 4 }; + enum { + VT_TYPE = 4 + }; LSHProjectionType type() const { return static_cast(GetField(VT_TYPE, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_TYPE) && verifier.EndTable(); + VerifyField(verifier, VT_TYPE) && + verifier.EndTable(); } - LSHProjectionOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - LSHProjectionOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + LSHProjectionOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LSHProjectionOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct LSHProjectionOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_type(LSHProjectionType type) { - fbb_.AddElement(LSHProjectionOptions::VT_TYPE, - static_cast(type), 0); + fbb_.AddElement(LSHProjectionOptions::VT_TYPE, static_cast(type), 0); } explicit LSHProjectionOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } LSHProjectionOptionsBuilder &operator=(const LSHProjectionOptionsBuilder &); @@ -1806,25 +1807,29 @@ inline flatbuffers::Offset CreateLSHProjectionOptions( return builder_.Finish(); } -flatbuffers::Offset CreateLSHProjectionOptions( - flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateLSHProjectionOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SVDFOptionsT : public flatbuffers::NativeTable { typedef SVDFOptions TableType; int32_t rank; ActivationFunctionType fused_activation_function; SVDFOptionsT() - : rank(0), fused_activation_function(ActivationFunctionType_NONE) {} + : rank(0), + fused_activation_function(ActivationFunctionType_NONE) { + } }; struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SVDFOptionsT NativeTableType; - enum { VT_RANK = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 }; - int32_t rank() const { return GetField(VT_RANK, 0); } + enum { + VT_RANK = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + int32_t rank() const { + return GetField(VT_RANK, 0); + } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1832,14 +1837,9 @@ struct SVDFOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - SVDFOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SVDFOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SVDFOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SVDFOptionsBuilder { @@ -1848,13 +1848,11 @@ struct SVDFOptionsBuilder { void add_rank(int32_t rank) { fbb_.AddElement(SVDFOptions::VT_RANK, rank, 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SVDFOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit SVDFOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SVDFOptionsBuilder &operator=(const SVDFOptionsBuilder &); @@ -1866,57 +1864,51 @@ struct SVDFOptionsBuilder { }; inline flatbuffers::Offset CreateSVDFOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t rank = 0, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + flatbuffers::FlatBufferBuilder &_fbb, + int32_t rank = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { SVDFOptionsBuilder builder_(_fbb); builder_.add_rank(rank); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateSVDFOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct RNNOptionsT : public flatbuffers::NativeTable { typedef RNNOptions TableType; ActivationFunctionType fused_activation_function; - RNNOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + RNNOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct RNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef RNNOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - RNNOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - RNNOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + RNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct RNNOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(RNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit RNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } RNNOptionsBuilder &operator=(const RNNOptionsBuilder &); @@ -1929,16 +1921,13 @@ struct RNNOptionsBuilder { inline flatbuffers::Offset CreateRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { RNNOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SequenceRNNOptionsT : public flatbuffers::NativeTable { typedef SequenceRNNOptions TableType; @@ -1946,16 +1935,21 @@ struct SequenceRNNOptionsT : public flatbuffers::NativeTable { ActivationFunctionType fused_activation_function; SequenceRNNOptionsT() : time_major(false), - fused_activation_function(ActivationFunctionType_NONE) {} + fused_activation_function(ActivationFunctionType_NONE) { + } }; struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SequenceRNNOptionsT NativeTableType; - enum { VT_TIME_MAJOR = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 }; - bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; } + enum { + VT_TIME_MAJOR = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + bool time_major() const { + return GetField(VT_TIME_MAJOR, 0) != 0; + } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -1963,30 +1957,22 @@ struct SequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - SequenceRNNOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SequenceRNNOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SequenceRNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SequenceRNNOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_time_major(bool time_major) { - fbb_.AddElement(SequenceRNNOptions::VT_TIME_MAJOR, - static_cast(time_major), 0); + fbb_.AddElement(SequenceRNNOptions::VT_TIME_MAJOR, static_cast(time_major), 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit SequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SequenceRNNOptionsBuilder &operator=(const SequenceRNNOptionsBuilder &); @@ -1998,18 +1984,16 @@ struct SequenceRNNOptionsBuilder { }; inline flatbuffers::Offset CreateSequenceRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + flatbuffers::FlatBufferBuilder &_fbb, + bool time_major = false, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { SequenceRNNOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); builder_.add_time_major(time_major); return builder_.Finish(); } -flatbuffers::Offset CreateSequenceRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSequenceRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable { typedef BidirectionalSequenceRNNOptions TableType; @@ -2017,17 +2001,21 @@ struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable { ActivationFunctionType fused_activation_function; BidirectionalSequenceRNNOptionsT() : time_major(false), - fused_activation_function(ActivationFunctionType_NONE) {} + fused_activation_function(ActivationFunctionType_NONE) { + } }; -struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef BidirectionalSequenceRNNOptionsT NativeTableType; - enum { VT_TIME_MAJOR = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 }; - bool time_major() const { return GetField(VT_TIME_MAJOR, 0) != 0; } + enum { + VT_TIME_MAJOR = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + bool time_major() const { + return GetField(VT_TIME_MAJOR, 0) != 0; + } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -2035,37 +2023,25 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - BidirectionalSequenceRNNOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - BidirectionalSequenceRNNOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const BidirectionalSequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BidirectionalSequenceRNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct BidirectionalSequenceRNNOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_time_major(bool time_major) { - fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_TIME_MAJOR, - static_cast(time_major), 0); - } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement( - BidirectionalSequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); - } - explicit BidirectionalSequenceRNNOptionsBuilder( - flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_TIME_MAJOR, static_cast(time_major), 0); + } + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(BidirectionalSequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); + } + explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - BidirectionalSequenceRNNOptionsBuilder &operator=( - const BidirectionalSequenceRNNOptionsBuilder &); + BidirectionalSequenceRNNOptionsBuilder &operator=(const BidirectionalSequenceRNNOptionsBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -2073,63 +2049,52 @@ struct BidirectionalSequenceRNNOptionsBuilder { } }; -inline flatbuffers::Offset -CreateBidirectionalSequenceRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { +inline flatbuffers::Offset CreateBidirectionalSequenceRNNOptions( + flatbuffers::FlatBufferBuilder &_fbb, + bool time_major = false, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { BidirectionalSequenceRNNOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); builder_.add_time_major(time_major); return builder_.Finish(); } -flatbuffers::Offset -CreateBidirectionalSequenceRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, - const BidirectionalSequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateBidirectionalSequenceRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct FullyConnectedOptionsT : public flatbuffers::NativeTable { typedef FullyConnectedOptions TableType; ActivationFunctionType fused_activation_function; FullyConnectedOptionsT() - : fused_activation_function(ActivationFunctionType_NONE) {} + : fused_activation_function(ActivationFunctionType_NONE) { + } }; -struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct FullyConnectedOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef FullyConnectedOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - FullyConnectedOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - FullyConnectedOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + FullyConnectedOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(FullyConnectedOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct FullyConnectedOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(FullyConnectedOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit FullyConnectedOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } FullyConnectedOptionsBuilder &operator=(const FullyConnectedOptionsBuilder &); @@ -2142,39 +2107,38 @@ struct FullyConnectedOptionsBuilder { inline flatbuffers::Offset CreateFullyConnectedOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { FullyConnectedOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateFullyConnectedOptions( - flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateFullyConnectedOptions(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SoftmaxOptionsT : public flatbuffers::NativeTable { typedef SoftmaxOptions TableType; float beta; - SoftmaxOptionsT() : beta(0.0f) {} + SoftmaxOptionsT() + : beta(0.0f) { + } }; struct SoftmaxOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SoftmaxOptionsT NativeTableType; - enum { VT_BETA = 4 }; - float beta() const { return GetField(VT_BETA, 0.0f); } + enum { + VT_BETA = 4 + }; + float beta() const { + return GetField(VT_BETA, 0.0f); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_BETA) && verifier.EndTable(); + VerifyField(verifier, VT_BETA) && + verifier.EndTable(); } - SoftmaxOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SoftmaxOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SoftmaxOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SoftmaxOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SoftmaxOptionsBuilder { @@ -2184,7 +2148,7 @@ struct SoftmaxOptionsBuilder { fbb_.AddElement(SoftmaxOptions::VT_BETA, beta, 0.0f); } explicit SoftmaxOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SoftmaxOptionsBuilder &operator=(const SoftmaxOptionsBuilder &); @@ -2196,32 +2160,36 @@ struct SoftmaxOptionsBuilder { }; inline flatbuffers::Offset CreateSoftmaxOptions( - flatbuffers::FlatBufferBuilder &_fbb, float beta = 0.0f) { + flatbuffers::FlatBufferBuilder &_fbb, + float beta = 0.0f) { SoftmaxOptionsBuilder builder_(_fbb); builder_.add_beta(beta); return builder_.Finish(); } -flatbuffers::Offset CreateSoftmaxOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSoftmaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct ConcatenationOptionsT : public flatbuffers::NativeTable { typedef ConcatenationOptions TableType; int32_t axis; ActivationFunctionType fused_activation_function; ConcatenationOptionsT() - : axis(0), fused_activation_function(ActivationFunctionType_NONE) {} + : axis(0), + fused_activation_function(ActivationFunctionType_NONE) { + } }; -struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ConcatenationOptionsT NativeTableType; - enum { VT_AXIS = 4, VT_FUSED_ACTIVATION_FUNCTION = 6 }; - int32_t axis() const { return GetField(VT_AXIS, 0); } + enum { + VT_AXIS = 4, + VT_FUSED_ACTIVATION_FUNCTION = 6 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -2229,14 +2197,9 @@ struct ConcatenationOptions FLATBUFFERS_FINAL_CLASS VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - ConcatenationOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - ConcatenationOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ConcatenationOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ConcatenationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ConcatenationOptionsBuilder { @@ -2245,13 +2208,11 @@ struct ConcatenationOptionsBuilder { void add_axis(int32_t axis) { fbb_.AddElement(ConcatenationOptions::VT_AXIS, axis, 0); } - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(ConcatenationOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(ConcatenationOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit ConcatenationOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ConcatenationOptionsBuilder &operator=(const ConcatenationOptionsBuilder &); @@ -2263,57 +2224,51 @@ struct ConcatenationOptionsBuilder { }; inline flatbuffers::Offset CreateConcatenationOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { ConcatenationOptionsBuilder builder_(_fbb); builder_.add_axis(axis); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateConcatenationOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateConcatenationOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct AddOptionsT : public flatbuffers::NativeTable { typedef AddOptions TableType; ActivationFunctionType fused_activation_function; - AddOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + AddOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct AddOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef AddOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - AddOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - AddOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + AddOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(AddOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct AddOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(AddOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(AddOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit AddOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } AddOptionsBuilder &operator=(const AddOptionsBuilder &); @@ -2326,55 +2281,48 @@ struct AddOptionsBuilder { inline flatbuffers::Offset CreateAddOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { AddOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateAddOptions( - flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateAddOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct MulOptionsT : public flatbuffers::NativeTable { typedef MulOptions TableType; ActivationFunctionType fused_activation_function; - MulOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + MulOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct MulOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef MulOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - MulOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - MulOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + MulOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct MulOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(MulOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(MulOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit MulOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } MulOptionsBuilder &operator=(const MulOptionsBuilder &); @@ -2387,55 +2335,48 @@ struct MulOptionsBuilder { inline flatbuffers::Offset CreateMulOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { MulOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateMulOptions( - flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct L2NormOptionsT : public flatbuffers::NativeTable { typedef L2NormOptions TableType; ActivationFunctionType fused_activation_function; - L2NormOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + L2NormOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct L2NormOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef L2NormOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - L2NormOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - L2NormOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + L2NormOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(L2NormOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct L2NormOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(L2NormOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(L2NormOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit L2NormOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } L2NormOptionsBuilder &operator=(const L2NormOptionsBuilder &); @@ -2448,16 +2389,13 @@ struct L2NormOptionsBuilder { inline flatbuffers::Offset CreateL2NormOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { L2NormOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateL2NormOptions( - flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateL2NormOptions(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct LocalResponseNormalizationOptionsT : public flatbuffers::NativeTable { typedef LocalResponseNormalizationOptions TableType; @@ -2466,61 +2404,66 @@ struct LocalResponseNormalizationOptionsT : public flatbuffers::NativeTable { float alpha; float beta; LocalResponseNormalizationOptionsT() - : radius(0), bias(0.0f), alpha(0.0f), beta(0.0f) {} + : radius(0), + bias(0.0f), + alpha(0.0f), + beta(0.0f) { + } }; -struct LocalResponseNormalizationOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct LocalResponseNormalizationOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef LocalResponseNormalizationOptionsT NativeTableType; - enum { VT_RADIUS = 4, VT_BIAS = 6, VT_ALPHA = 8, VT_BETA = 10 }; - int32_t radius() const { return GetField(VT_RADIUS, 0); } - float bias() const { return GetField(VT_BIAS, 0.0f); } - float alpha() const { return GetField(VT_ALPHA, 0.0f); } - float beta() const { return GetField(VT_BETA, 0.0f); } + enum { + VT_RADIUS = 4, + VT_BIAS = 6, + VT_ALPHA = 8, + VT_BETA = 10 + }; + int32_t radius() const { + return GetField(VT_RADIUS, 0); + } + float bias() const { + return GetField(VT_BIAS, 0.0f); + } + float alpha() const { + return GetField(VT_ALPHA, 0.0f); + } + float beta() const { + return GetField(VT_BETA, 0.0f); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_RADIUS) && VerifyField(verifier, VT_BIAS) && VerifyField(verifier, VT_ALPHA) && - VerifyField(verifier, VT_BETA) && verifier.EndTable(); + VerifyField(verifier, VT_BETA) && + verifier.EndTable(); } - LocalResponseNormalizationOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - LocalResponseNormalizationOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const LocalResponseNormalizationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + LocalResponseNormalizationOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LocalResponseNormalizationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct LocalResponseNormalizationOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_radius(int32_t radius) { - fbb_.AddElement(LocalResponseNormalizationOptions::VT_RADIUS, - radius, 0); + fbb_.AddElement(LocalResponseNormalizationOptions::VT_RADIUS, radius, 0); } void add_bias(float bias) { - fbb_.AddElement(LocalResponseNormalizationOptions::VT_BIAS, bias, - 0.0f); + fbb_.AddElement(LocalResponseNormalizationOptions::VT_BIAS, bias, 0.0f); } void add_alpha(float alpha) { - fbb_.AddElement(LocalResponseNormalizationOptions::VT_ALPHA, alpha, - 0.0f); + fbb_.AddElement(LocalResponseNormalizationOptions::VT_ALPHA, alpha, 0.0f); } void add_beta(float beta) { - fbb_.AddElement(LocalResponseNormalizationOptions::VT_BETA, beta, - 0.0f); + fbb_.AddElement(LocalResponseNormalizationOptions::VT_BETA, beta, 0.0f); } - explicit LocalResponseNormalizationOptionsBuilder( - flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit LocalResponseNormalizationOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - LocalResponseNormalizationOptionsBuilder &operator=( - const LocalResponseNormalizationOptionsBuilder &); + LocalResponseNormalizationOptionsBuilder &operator=(const LocalResponseNormalizationOptionsBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -2528,10 +2471,12 @@ struct LocalResponseNormalizationOptionsBuilder { } }; -inline flatbuffers::Offset -CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, - int32_t radius = 0, float bias = 0.0f, - float alpha = 0.0f, float beta = 0.0f) { +inline flatbuffers::Offset CreateLocalResponseNormalizationOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t radius = 0, + float bias = 0.0f, + float alpha = 0.0f, + float beta = 0.0f) { LocalResponseNormalizationOptionsBuilder builder_(_fbb); builder_.add_beta(beta); builder_.add_alpha(alpha); @@ -2540,11 +2485,7 @@ CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, return builder_.Finish(); } -flatbuffers::Offset -CreateLocalResponseNormalizationOptions( - flatbuffers::FlatBufferBuilder &_fbb, - const LocalResponseNormalizationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct LSTMOptionsT : public flatbuffers::NativeTable { typedef LSTMOptions TableType; @@ -2554,41 +2495,43 @@ struct LSTMOptionsT : public flatbuffers::NativeTable { LSTMOptionsT() : fused_activation_function(ActivationFunctionType_NONE), cell_clip(0.0f), - proj_clip(0.0f) {} + proj_clip(0.0f) { + } }; struct LSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef LSTMOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4, VT_CELL_CLIP = 6, VT_PROJ_CLIP = 8 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { + return GetField(VT_CELL_CLIP, 0.0f); + } + float proj_clip() const { + return GetField(VT_PROJ_CLIP, 0.0f); } - float cell_clip() const { return GetField(VT_CELL_CLIP, 0.0f); } - float proj_clip() const { return GetField(VT_PROJ_CLIP, 0.0f); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && VerifyField(verifier, VT_CELL_CLIP) && - VerifyField(verifier, VT_PROJ_CLIP) && verifier.EndTable(); + VerifyField(verifier, VT_PROJ_CLIP) && + verifier.EndTable(); } - LSTMOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - LSTMOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + LSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct LSTMOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(LSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(LSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } void add_cell_clip(float cell_clip) { fbb_.AddElement(LSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); @@ -2597,7 +2540,7 @@ struct LSTMOptionsBuilder { fbb_.AddElement(LSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); } explicit LSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } LSTMOptionsBuilder &operator=(const LSTMOptionsBuilder &); @@ -2610,9 +2553,9 @@ struct LSTMOptionsBuilder { inline flatbuffers::Offset CreateLSTMOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE, - float cell_clip = 0.0f, float proj_clip = 0.0f) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + float cell_clip = 0.0f, + float proj_clip = 0.0f) { LSTMOptionsBuilder builder_(_fbb); builder_.add_proj_clip(proj_clip); builder_.add_cell_clip(cell_clip); @@ -2620,9 +2563,7 @@ inline flatbuffers::Offset CreateLSTMOptions( return builder_.Finish(); } -flatbuffers::Offset CreateLSTMOptions( - flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct ResizeBilinearOptionsT : public flatbuffers::NativeTable { typedef ResizeBilinearOptions TableType; @@ -2657,7 +2598,7 @@ struct ResizeBilinearOptionsBuilder { fbb_.AddElement(ResizeBilinearOptions::VT_ALIGN_CORNERS, static_cast(align_corners), 0); } explicit ResizeBilinearOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ResizeBilinearOptionsBuilder &operator=(const ResizeBilinearOptionsBuilder &); @@ -2681,25 +2622,27 @@ flatbuffers::Offset CreateResizeBilinearOptions(flatbuffe struct CallOptionsT : public flatbuffers::NativeTable { typedef CallOptions TableType; uint32_t subgraph; - CallOptionsT() : subgraph(0) {} + CallOptionsT() + : subgraph(0) { + } }; struct CallOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef CallOptionsT NativeTableType; - enum { VT_SUBGRAPH = 4 }; - uint32_t subgraph() const { return GetField(VT_SUBGRAPH, 0); } + enum { + VT_SUBGRAPH = 4 + }; + uint32_t subgraph() const { + return GetField(VT_SUBGRAPH, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_SUBGRAPH) && verifier.EndTable(); + VerifyField(verifier, VT_SUBGRAPH) && + verifier.EndTable(); } - CallOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - CallOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + CallOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(CallOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct CallOptionsBuilder { @@ -2709,7 +2652,7 @@ struct CallOptionsBuilder { fbb_.AddElement(CallOptions::VT_SUBGRAPH, subgraph, 0); } explicit CallOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } CallOptionsBuilder &operator=(const CallOptionsBuilder &); @@ -2721,41 +2664,37 @@ struct CallOptionsBuilder { }; inline flatbuffers::Offset CreateCallOptions( - flatbuffers::FlatBufferBuilder &_fbb, uint32_t subgraph = 0) { + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t subgraph = 0) { CallOptionsBuilder builder_(_fbb); builder_.add_subgraph(subgraph); return builder_.Finish(); } -flatbuffers::Offset CreateCallOptions( - flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateCallOptions(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct PadOptionsT : public flatbuffers::NativeTable { typedef PadOptions TableType; - PadOptionsT() {} + PadOptionsT() { + } }; struct PadOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef PadOptionsT NativeTableType; bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && verifier.EndTable(); + return VerifyTableStart(verifier) && + verifier.EndTable(); } - PadOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - PadOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + PadOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(PadOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct PadOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; explicit PadOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } PadOptionsBuilder &operator=(const PadOptionsBuilder &); @@ -2772,45 +2711,42 @@ inline flatbuffers::Offset CreatePadOptions( return builder_.Finish(); } -flatbuffers::Offset CreatePadOptions( - flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreatePadOptions(flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct ReshapeOptionsT : public flatbuffers::NativeTable { typedef ReshapeOptions TableType; std::vector new_shape; - ReshapeOptionsT() {} + ReshapeOptionsT() { + } }; struct ReshapeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef ReshapeOptionsT NativeTableType; - enum { VT_NEW_SHAPE = 4 }; + enum { + VT_NEW_SHAPE = 4 + }; const flatbuffers::Vector *new_shape() const { return GetPointer *>(VT_NEW_SHAPE); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_NEW_SHAPE) && - verifier.Verify(new_shape()) && verifier.EndTable(); + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NEW_SHAPE) && + verifier.Verify(new_shape()) && + verifier.EndTable(); } - ReshapeOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - ReshapeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ReshapeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ReshapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ReshapeOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_new_shape( - flatbuffers::Offset> new_shape) { + void add_new_shape(flatbuffers::Offset> new_shape) { fbb_.AddOffset(ReshapeOptions::VT_NEW_SHAPE, new_shape); } explicit ReshapeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ReshapeOptionsBuilder &operator=(const ReshapeOptionsBuilder &); @@ -2833,39 +2769,34 @@ inline flatbuffers::Offset CreateReshapeOptionsDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *new_shape = nullptr) { return tflite::CreateReshapeOptions( - _fbb, new_shape ? _fbb.CreateVector(*new_shape) : 0); + _fbb, + new_shape ? _fbb.CreateVector(*new_shape) : 0); } -flatbuffers::Offset CreateReshapeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateReshapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SpaceToBatchNDOptionsT : public flatbuffers::NativeTable { typedef SpaceToBatchNDOptions TableType; - SpaceToBatchNDOptionsT() {} + SpaceToBatchNDOptionsT() { + } }; -struct SpaceToBatchNDOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct SpaceToBatchNDOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SpaceToBatchNDOptionsT NativeTableType; bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && verifier.EndTable(); + return VerifyTableStart(verifier) && + verifier.EndTable(); } - SpaceToBatchNDOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SpaceToBatchNDOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SpaceToBatchNDOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SpaceToBatchNDOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SpaceToBatchNDOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; explicit SpaceToBatchNDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SpaceToBatchNDOptionsBuilder &operator=(const SpaceToBatchNDOptionsBuilder &); @@ -2882,36 +2813,30 @@ inline flatbuffers::Offset CreateSpaceToBatchNDOptions( return builder_.Finish(); } -flatbuffers::Offset CreateSpaceToBatchNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSpaceToBatchNDOptions(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct BatchToSpaceNDOptionsT : public flatbuffers::NativeTable { typedef BatchToSpaceNDOptions TableType; - BatchToSpaceNDOptionsT() {} + BatchToSpaceNDOptionsT() { + } }; -struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct BatchToSpaceNDOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef BatchToSpaceNDOptionsT NativeTableType; bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && verifier.EndTable(); + return VerifyTableStart(verifier) && + verifier.EndTable(); } - BatchToSpaceNDOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - BatchToSpaceNDOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + BatchToSpaceNDOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BatchToSpaceNDOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct BatchToSpaceNDOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; explicit BatchToSpaceNDOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } BatchToSpaceNDOptionsBuilder &operator=(const BatchToSpaceNDOptionsBuilder &); @@ -2928,9 +2853,7 @@ inline flatbuffers::Offset CreateBatchToSpaceNDOptions( return builder_.Finish(); } -flatbuffers::Offset CreateBatchToSpaceNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateBatchToSpaceNDOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SkipGramOptionsT : public flatbuffers::NativeTable { typedef SkipGramOptions TableType; @@ -2938,13 +2861,22 @@ struct SkipGramOptionsT : public flatbuffers::NativeTable { int32_t max_skip_size; bool include_all_ngrams; SkipGramOptionsT() - : ngram_size(0), max_skip_size(0), include_all_ngrams(false) {} + : ngram_size(0), + max_skip_size(0), + include_all_ngrams(false) { + } }; struct SkipGramOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SkipGramOptionsT NativeTableType; - enum { VT_NGRAM_SIZE = 4, VT_MAX_SKIP_SIZE = 6, VT_INCLUDE_ALL_NGRAMS = 8 }; - int32_t ngram_size() const { return GetField(VT_NGRAM_SIZE, 0); } + enum { + VT_NGRAM_SIZE = 4, + VT_MAX_SKIP_SIZE = 6, + VT_INCLUDE_ALL_NGRAMS = 8 + }; + int32_t ngram_size() const { + return GetField(VT_NGRAM_SIZE, 0); + } int32_t max_skip_size() const { return GetField(VT_MAX_SKIP_SIZE, 0); } @@ -2958,14 +2890,9 @@ struct SkipGramOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VerifyField(verifier, VT_INCLUDE_ALL_NGRAMS) && verifier.EndTable(); } - SkipGramOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SkipGramOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SkipGramOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SkipGramOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SkipGramOptionsBuilder { @@ -2975,15 +2902,13 @@ struct SkipGramOptionsBuilder { fbb_.AddElement(SkipGramOptions::VT_NGRAM_SIZE, ngram_size, 0); } void add_max_skip_size(int32_t max_skip_size) { - fbb_.AddElement(SkipGramOptions::VT_MAX_SKIP_SIZE, max_skip_size, - 0); + fbb_.AddElement(SkipGramOptions::VT_MAX_SKIP_SIZE, max_skip_size, 0); } void add_include_all_ngrams(bool include_all_ngrams) { - fbb_.AddElement(SkipGramOptions::VT_INCLUDE_ALL_NGRAMS, - static_cast(include_all_ngrams), 0); + fbb_.AddElement(SkipGramOptions::VT_INCLUDE_ALL_NGRAMS, static_cast(include_all_ngrams), 0); } explicit SkipGramOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SkipGramOptionsBuilder &operator=(const SkipGramOptionsBuilder &); @@ -2995,8 +2920,10 @@ struct SkipGramOptionsBuilder { }; inline flatbuffers::Offset CreateSkipGramOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t ngram_size = 0, - int32_t max_skip_size = 0, bool include_all_ngrams = false) { + flatbuffers::FlatBufferBuilder &_fbb, + int32_t ngram_size = 0, + int32_t max_skip_size = 0, + bool include_all_ngrams = false) { SkipGramOptionsBuilder builder_(_fbb); builder_.add_max_skip_size(max_skip_size); builder_.add_ngram_size(ngram_size); @@ -3004,33 +2931,32 @@ inline flatbuffers::Offset CreateSkipGramOptions( return builder_.Finish(); } -flatbuffers::Offset CreateSkipGramOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSkipGramOptions(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SpaceToDepthOptionsT : public flatbuffers::NativeTable { typedef SpaceToDepthOptions TableType; int32_t block_size; - SpaceToDepthOptionsT() : block_size(0) {} + SpaceToDepthOptionsT() + : block_size(0) { + } }; -struct SpaceToDepthOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct SpaceToDepthOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SpaceToDepthOptionsT NativeTableType; - enum { VT_BLOCK_SIZE = 4 }; - int32_t block_size() const { return GetField(VT_BLOCK_SIZE, 0); } + enum { + VT_BLOCK_SIZE = 4 + }; + int32_t block_size() const { + return GetField(VT_BLOCK_SIZE, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_BLOCK_SIZE) && verifier.EndTable(); + VerifyField(verifier, VT_BLOCK_SIZE) && + verifier.EndTable(); } - SpaceToDepthOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SpaceToDepthOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SpaceToDepthOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SpaceToDepthOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SpaceToDepthOptionsBuilder { @@ -3040,7 +2966,7 @@ struct SpaceToDepthOptionsBuilder { fbb_.AddElement(SpaceToDepthOptions::VT_BLOCK_SIZE, block_size, 0); } explicit SpaceToDepthOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SpaceToDepthOptionsBuilder &operator=(const SpaceToDepthOptionsBuilder &); @@ -3052,54 +2978,49 @@ struct SpaceToDepthOptionsBuilder { }; inline flatbuffers::Offset CreateSpaceToDepthOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t block_size = 0) { + flatbuffers::FlatBufferBuilder &_fbb, + int32_t block_size = 0) { SpaceToDepthOptionsBuilder builder_(_fbb); builder_.add_block_size(block_size); return builder_.Finish(); } -flatbuffers::Offset CreateSpaceToDepthOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSpaceToDepthOptions(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SubOptionsT : public flatbuffers::NativeTable { typedef SubOptions TableType; ActivationFunctionType fused_activation_function; - SubOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + SubOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct SubOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SubOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - SubOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SubOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SubOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SubOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SubOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(SubOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(SubOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit SubOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SubOptionsBuilder &operator=(const SubOptionsBuilder &); @@ -3112,55 +3033,48 @@ struct SubOptionsBuilder { inline flatbuffers::Offset CreateSubOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { SubOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateSubOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSubOptions(flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct DivOptionsT : public flatbuffers::NativeTable { typedef DivOptions TableType; ActivationFunctionType fused_activation_function; - DivOptionsT() : fused_activation_function(ActivationFunctionType_NONE) {} + DivOptionsT() + : fused_activation_function(ActivationFunctionType_NONE) { + } }; struct DivOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef DivOptionsT NativeTableType; - enum { VT_FUSED_ACTIVATION_FUNCTION = 4 }; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4 + }; ActivationFunctionType fused_activation_function() const { - return static_cast( - GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); + return static_cast(GetField(VT_FUSED_ACTIVATION_FUNCTION, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_FUSED_ACTIVATION_FUNCTION) && verifier.EndTable(); } - DivOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - DivOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + DivOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct DivOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_fused_activation_function( - ActivationFunctionType fused_activation_function) { - fbb_.AddElement(DivOptions::VT_FUSED_ACTIVATION_FUNCTION, - static_cast(fused_activation_function), 0); + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement(DivOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast(fused_activation_function), 0); } explicit DivOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } DivOptionsBuilder &operator=(const DivOptionsBuilder &); @@ -3173,59 +3087,91 @@ struct DivOptionsBuilder { inline flatbuffers::Offset CreateDivOptions( flatbuffers::FlatBufferBuilder &_fbb, - ActivationFunctionType fused_activation_function = - ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { DivOptionsBuilder builder_(_fbb); builder_.add_fused_activation_function(fused_activation_function); return builder_.Finish(); } -flatbuffers::Offset CreateDivOptions( - flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct TopKV2OptionsT : public flatbuffers::NativeTable { + typedef TopKV2Options TableType; + TopKV2OptionsT() { + } +}; + +struct TopKV2Options FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef TopKV2OptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + TopKV2OptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TopKV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct TopKV2OptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit TopKV2OptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + TopKV2OptionsBuilder &operator=(const TopKV2OptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateTopKV2Options( + flatbuffers::FlatBufferBuilder &_fbb) { + TopKV2OptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateTopKV2Options(flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct EmbeddingLookupSparseOptionsT : public flatbuffers::NativeTable { typedef EmbeddingLookupSparseOptions TableType; CombinerType combiner; - EmbeddingLookupSparseOptionsT() : combiner(CombinerType_SUM) {} + EmbeddingLookupSparseOptionsT() + : combiner(CombinerType_SUM) { + } }; -struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct EmbeddingLookupSparseOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef EmbeddingLookupSparseOptionsT NativeTableType; - enum { VT_COMBINER = 4 }; + enum { + VT_COMBINER = 4 + }; CombinerType combiner() const { return static_cast(GetField(VT_COMBINER, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_COMBINER) && verifier.EndTable(); + VerifyField(verifier, VT_COMBINER) && + verifier.EndTable(); } - EmbeddingLookupSparseOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + EmbeddingLookupSparseOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(EmbeddingLookupSparseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct EmbeddingLookupSparseOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_combiner(CombinerType combiner) { - fbb_.AddElement(EmbeddingLookupSparseOptions::VT_COMBINER, - static_cast(combiner), 0); + fbb_.AddElement(EmbeddingLookupSparseOptions::VT_COMBINER, static_cast(combiner), 0); } - explicit EmbeddingLookupSparseOptionsBuilder( - flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + explicit EmbeddingLookupSparseOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } - EmbeddingLookupSparseOptionsBuilder &operator=( - const EmbeddingLookupSparseOptionsBuilder &); + EmbeddingLookupSparseOptionsBuilder &operator=(const EmbeddingLookupSparseOptionsBuilder &); flatbuffers::Offset Finish() { const auto end = fbb_.EndTable(start_); auto o = flatbuffers::Offset(end); @@ -3233,42 +3179,40 @@ struct EmbeddingLookupSparseOptionsBuilder { } }; -inline flatbuffers::Offset -CreateEmbeddingLookupSparseOptions(flatbuffers::FlatBufferBuilder &_fbb, - CombinerType combiner = CombinerType_SUM) { +inline flatbuffers::Offset CreateEmbeddingLookupSparseOptions( + flatbuffers::FlatBufferBuilder &_fbb, + CombinerType combiner = CombinerType_SUM) { EmbeddingLookupSparseOptionsBuilder builder_(_fbb); builder_.add_combiner(combiner); return builder_.Finish(); } -flatbuffers::Offset -CreateEmbeddingLookupSparseOptions( - flatbuffers::FlatBufferBuilder &_fbb, - const EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateEmbeddingLookupSparseOptions(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct GatherOptionsT : public flatbuffers::NativeTable { typedef GatherOptions TableType; int32_t axis; - GatherOptionsT() : axis(0) {} + GatherOptionsT() + : axis(0) { + } }; struct GatherOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef GatherOptionsT NativeTableType; - enum { VT_AXIS = 4 }; - int32_t axis() const { return GetField(VT_AXIS, 0); } + enum { + VT_AXIS = 4 + }; + int32_t axis() const { + return GetField(VT_AXIS, 0); + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_AXIS) && verifier.EndTable(); + VerifyField(verifier, VT_AXIS) && + verifier.EndTable(); } - GatherOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - GatherOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + GatherOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(GatherOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct GatherOptionsBuilder { @@ -3278,7 +3222,7 @@ struct GatherOptionsBuilder { fbb_.AddElement(GatherOptions::VT_AXIS, axis, 0); } explicit GatherOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } GatherOptionsBuilder &operator=(const GatherOptionsBuilder &); @@ -3290,41 +3234,37 @@ struct GatherOptionsBuilder { }; inline flatbuffers::Offset CreateGatherOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0) { + flatbuffers::FlatBufferBuilder &_fbb, + int32_t axis = 0) { GatherOptionsBuilder builder_(_fbb); builder_.add_axis(axis); return builder_.Finish(); } -flatbuffers::Offset CreateGatherOptions( - flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateGatherOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct TransposeOptionsT : public flatbuffers::NativeTable { typedef TransposeOptions TableType; - TransposeOptionsT() {} + TransposeOptionsT() { + } }; struct TransposeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef TransposeOptionsT NativeTableType; bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && verifier.EndTable(); + return VerifyTableStart(verifier) && + verifier.EndTable(); } - TransposeOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - TransposeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + TransposeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(TransposeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct TransposeOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; explicit TransposeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } TransposeOptionsBuilder &operator=(const TransposeOptionsBuilder &); @@ -3341,43 +3281,82 @@ inline flatbuffers::Offset CreateTransposeOptions( return builder_.Finish(); } -flatbuffers::Offset CreateTransposeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateTransposeOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct ExpOptionsT : public flatbuffers::NativeTable { + typedef ExpOptions TableType; + ExpOptionsT() { + } +}; + +struct ExpOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ExpOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + ExpOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ExpOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct ExpOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit ExpOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ExpOptionsBuilder &operator=(const ExpOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateExpOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + ExpOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct MeanOptionsT : public flatbuffers::NativeTable { typedef MeanOptions TableType; bool keep_dims; - MeanOptionsT() : keep_dims(false) {} + MeanOptionsT() + : keep_dims(false) { + } }; struct MeanOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef MeanOptionsT NativeTableType; - enum { VT_KEEP_DIMS = 4 }; - bool keep_dims() const { return GetField(VT_KEEP_DIMS, 0) != 0; } + enum { + VT_KEEP_DIMS = 4 + }; + bool keep_dims() const { + return GetField(VT_KEEP_DIMS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && - VerifyField(verifier, VT_KEEP_DIMS) && verifier.EndTable(); + VerifyField(verifier, VT_KEEP_DIMS) && + verifier.EndTable(); } - MeanOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - MeanOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + MeanOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct MeanOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_keep_dims(bool keep_dims) { - fbb_.AddElement(MeanOptions::VT_KEEP_DIMS, - static_cast(keep_dims), 0); + fbb_.AddElement(MeanOptions::VT_KEEP_DIMS, static_cast(keep_dims), 0); } explicit MeanOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } MeanOptionsBuilder &operator=(const MeanOptionsBuilder &); @@ -3389,52 +3368,49 @@ struct MeanOptionsBuilder { }; inline flatbuffers::Offset CreateMeanOptions( - flatbuffers::FlatBufferBuilder &_fbb, bool keep_dims = false) { + flatbuffers::FlatBufferBuilder &_fbb, + bool keep_dims = false) { MeanOptionsBuilder builder_(_fbb); builder_.add_keep_dims(keep_dims); return builder_.Finish(); } -flatbuffers::Offset CreateMeanOptions( - flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SqueezeOptionsT : public flatbuffers::NativeTable { typedef SqueezeOptions TableType; std::vector squeeze_dims; - SqueezeOptionsT() {} + SqueezeOptionsT() { + } }; struct SqueezeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef SqueezeOptionsT NativeTableType; - enum { VT_SQUEEZE_DIMS = 4 }; + enum { + VT_SQUEEZE_DIMS = 4 + }; const flatbuffers::Vector *squeeze_dims() const { return GetPointer *>(VT_SQUEEZE_DIMS); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_SQUEEZE_DIMS) && - verifier.Verify(squeeze_dims()) && verifier.EndTable(); + verifier.Verify(squeeze_dims()) && + verifier.EndTable(); } - SqueezeOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SqueezeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SqueezeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SqueezeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SqueezeOptionsBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_squeeze_dims( - flatbuffers::Offset> squeeze_dims) { + void add_squeeze_dims(flatbuffers::Offset> squeeze_dims) { fbb_.AddOffset(SqueezeOptions::VT_SQUEEZE_DIMS, squeeze_dims); } explicit SqueezeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SqueezeOptionsBuilder &operator=(const SqueezeOptionsBuilder &); @@ -3457,30 +3433,83 @@ inline flatbuffers::Offset CreateSqueezeOptionsDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *squeeze_dims = nullptr) { return tflite::CreateSqueezeOptions( - _fbb, squeeze_dims ? _fbb.CreateVector(*squeeze_dims) : 0); + _fbb, + squeeze_dims ? _fbb.CreateVector(*squeeze_dims) : 0); } -flatbuffers::Offset CreateSqueezeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSqueezeOptions(flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -struct StridedSliceOptionsT : public flatbuffers::NativeTable { - typedef StridedSliceOptions TableType; - int32_t begin_mask; - int32_t end_mask; - int32_t ellipsis_mask; - int32_t new_axis_mask; - int32_t shrink_axis_mask; - StridedSliceOptionsT() - : begin_mask(0), - end_mask(0), - ellipsis_mask(0), +struct SplitOptionsT : public flatbuffers::NativeTable { + typedef SplitOptions TableType; + int32_t num_splits; + SplitOptionsT() + : num_splits(0) { + } +}; + +struct SplitOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef SplitOptionsT NativeTableType; + enum { + VT_NUM_SPLITS = 4 + }; + int32_t num_splits() const { + return GetField(VT_NUM_SPLITS, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_NUM_SPLITS) && + verifier.EndTable(); + } + SplitOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SplitOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct SplitOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_num_splits(int32_t num_splits) { + fbb_.AddElement(SplitOptions::VT_NUM_SPLITS, num_splits, 0); + } + explicit SplitOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + SplitOptionsBuilder &operator=(const SplitOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateSplitOptions( + flatbuffers::FlatBufferBuilder &_fbb, + int32_t num_splits = 0) { + SplitOptionsBuilder builder_(_fbb); + builder_.add_num_splits(num_splits); + return builder_.Finish(); +} + +flatbuffers::Offset CreateSplitOptions(flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct StridedSliceOptionsT : public flatbuffers::NativeTable { + typedef StridedSliceOptions TableType; + int32_t begin_mask; + int32_t end_mask; + int32_t ellipsis_mask; + int32_t new_axis_mask; + int32_t shrink_axis_mask; + StridedSliceOptionsT() + : begin_mask(0), + end_mask(0), + ellipsis_mask(0), new_axis_mask(0), - shrink_axis_mask(0) {} + shrink_axis_mask(0) { + } }; -struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS - : private flatbuffers::Table { +struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef StridedSliceOptionsT NativeTableType; enum { VT_BEGIN_MASK = 4, @@ -3489,8 +3518,12 @@ struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS VT_NEW_AXIS_MASK = 10, VT_SHRINK_AXIS_MASK = 12 }; - int32_t begin_mask() const { return GetField(VT_BEGIN_MASK, 0); } - int32_t end_mask() const { return GetField(VT_END_MASK, 0); } + int32_t begin_mask() const { + return GetField(VT_BEGIN_MASK, 0); + } + int32_t end_mask() const { + return GetField(VT_END_MASK, 0); + } int32_t ellipsis_mask() const { return GetField(VT_ELLIPSIS_MASK, 0); } @@ -3509,14 +3542,9 @@ struct StridedSliceOptions FLATBUFFERS_FINAL_CLASS VerifyField(verifier, VT_SHRINK_AXIS_MASK) && verifier.EndTable(); } - StridedSliceOptionsT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - StridedSliceOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + StridedSliceOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(StridedSliceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct StridedSliceOptionsBuilder { @@ -3529,19 +3557,16 @@ struct StridedSliceOptionsBuilder { fbb_.AddElement(StridedSliceOptions::VT_END_MASK, end_mask, 0); } void add_ellipsis_mask(int32_t ellipsis_mask) { - fbb_.AddElement(StridedSliceOptions::VT_ELLIPSIS_MASK, - ellipsis_mask, 0); + fbb_.AddElement(StridedSliceOptions::VT_ELLIPSIS_MASK, ellipsis_mask, 0); } void add_new_axis_mask(int32_t new_axis_mask) { - fbb_.AddElement(StridedSliceOptions::VT_NEW_AXIS_MASK, - new_axis_mask, 0); + fbb_.AddElement(StridedSliceOptions::VT_NEW_AXIS_MASK, new_axis_mask, 0); } void add_shrink_axis_mask(int32_t shrink_axis_mask) { - fbb_.AddElement(StridedSliceOptions::VT_SHRINK_AXIS_MASK, - shrink_axis_mask, 0); + fbb_.AddElement(StridedSliceOptions::VT_SHRINK_AXIS_MASK, shrink_axis_mask, 0); } explicit StridedSliceOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } StridedSliceOptionsBuilder &operator=(const StridedSliceOptionsBuilder &); @@ -3553,8 +3578,11 @@ struct StridedSliceOptionsBuilder { }; inline flatbuffers::Offset CreateStridedSliceOptions( - flatbuffers::FlatBufferBuilder &_fbb, int32_t begin_mask = 0, - int32_t end_mask = 0, int32_t ellipsis_mask = 0, int32_t new_axis_mask = 0, + flatbuffers::FlatBufferBuilder &_fbb, + int32_t begin_mask = 0, + int32_t end_mask = 0, + int32_t ellipsis_mask = 0, + int32_t new_axis_mask = 0, int32_t shrink_axis_mask = 0) { StridedSliceOptionsBuilder builder_(_fbb); builder_.add_shrink_axis_mask(shrink_axis_mask); @@ -3565,20 +3593,63 @@ inline flatbuffers::Offset CreateStridedSliceOptions( return builder_.Finish(); } -flatbuffers::Offset CreateStridedSliceOptions( - flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateStridedSliceOptions(flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + +struct LogSoftmaxOptionsT : public flatbuffers::NativeTable { + typedef LogSoftmaxOptions TableType; + LogSoftmaxOptionsT() { + } +}; + +struct LogSoftmaxOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef LogSoftmaxOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + LogSoftmaxOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(LogSoftmaxOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogSoftmaxOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct LogSoftmaxOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit LogSoftmaxOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + LogSoftmaxOptionsBuilder &operator=(const LogSoftmaxOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateLogSoftmaxOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + LogSoftmaxOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateLogSoftmaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogSoftmaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; std::string custom_code; - OperatorCodeT() : builtin_code(BuiltinOperator_ADD) {} + OperatorCodeT() + : builtin_code(BuiltinOperator_ADD) { + } }; struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef OperatorCodeT NativeTableType; - enum { VT_BUILTIN_CODE = 4, VT_CUSTOM_CODE = 6 }; + enum { + VT_BUILTIN_CODE = 4, + VT_CUSTOM_CODE = 6 + }; BuiltinOperator builtin_code() const { return static_cast(GetField(VT_BUILTIN_CODE, 0)); } @@ -3589,30 +3660,25 @@ struct OperatorCode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return VerifyTableStart(verifier) && VerifyField(verifier, VT_BUILTIN_CODE) && VerifyOffset(verifier, VT_CUSTOM_CODE) && - verifier.Verify(custom_code()) && verifier.EndTable(); + verifier.Verify(custom_code()) && + verifier.EndTable(); } - OperatorCodeT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - OperatorCodeT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + OperatorCodeT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OperatorCodeT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct OperatorCodeBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; void add_builtin_code(BuiltinOperator builtin_code) { - fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, - static_cast(builtin_code), 0); + fbb_.AddElement(OperatorCode::VT_BUILTIN_CODE, static_cast(builtin_code), 0); } void add_custom_code(flatbuffers::Offset custom_code) { fbb_.AddOffset(OperatorCode::VT_CUSTOM_CODE, custom_code); } explicit OperatorCodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) - : fbb_(_fbb) { + : fbb_(_fbb) { start_ = fbb_.StartTable(); } OperatorCodeBuilder &operator=(const OperatorCodeBuilder &); @@ -3638,12 +3704,12 @@ inline flatbuffers::Offset CreateOperatorCodeDirect( BuiltinOperator builtin_code = BuiltinOperator_ADD, const char *custom_code = nullptr) { return tflite::CreateOperatorCode( - _fbb, builtin_code, custom_code ? _fbb.CreateString(custom_code) : 0); + _fbb, + builtin_code, + custom_code ? _fbb.CreateString(custom_code) : 0); } -flatbuffers::Offset CreateOperatorCode( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct OperatorT : public flatbuffers::NativeTable { typedef Operator TableType; @@ -3655,7 +3721,8 @@ struct OperatorT : public flatbuffers::NativeTable { CustomOptionsFormat custom_options_format; OperatorT() : opcode_index(0), - custom_options_format(CustomOptionsFormat_FLEXBUFFERS) {} + custom_options_format(CustomOptionsFormat_FLEXBUFFERS) { + } }; struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -3679,398 +3746,290 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer *>(VT_OUTPUTS); } BuiltinOptions builtin_options_type() const { - return static_cast( - GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); + return static_cast(GetField(VT_BUILTIN_OPTIONS_TYPE, 0)); } const void *builtin_options() const { return GetPointer(VT_BUILTIN_OPTIONS); } - template - const T *builtin_options_as() const; + template const T *builtin_options_as() const; const Conv2DOptions *builtin_options_as_Conv2DOptions() const { - return builtin_options_type() == BuiltinOptions_Conv2DOptions - ? static_cast(builtin_options()) - : nullptr; - } - const DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() - const { - return builtin_options_type() == BuiltinOptions_DepthwiseConv2DOptions - ? static_cast(builtin_options()) - : nullptr; - } - const ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() - const { - return builtin_options_type() == BuiltinOptions_ConcatEmbeddingsOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_Conv2DOptions ? static_cast(builtin_options()) : nullptr; + } + const DepthwiseConv2DOptions *builtin_options_as_DepthwiseConv2DOptions() const { + return builtin_options_type() == BuiltinOptions_DepthwiseConv2DOptions ? static_cast(builtin_options()) : nullptr; + } + const ConcatEmbeddingsOptions *builtin_options_as_ConcatEmbeddingsOptions() const { + return builtin_options_type() == BuiltinOptions_ConcatEmbeddingsOptions ? static_cast(builtin_options()) : nullptr; } const LSHProjectionOptions *builtin_options_as_LSHProjectionOptions() const { - return builtin_options_type() == BuiltinOptions_LSHProjectionOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_LSHProjectionOptions ? static_cast(builtin_options()) : nullptr; } const Pool2DOptions *builtin_options_as_Pool2DOptions() const { - return builtin_options_type() == BuiltinOptions_Pool2DOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_Pool2DOptions ? static_cast(builtin_options()) : nullptr; } const SVDFOptions *builtin_options_as_SVDFOptions() const { - return builtin_options_type() == BuiltinOptions_SVDFOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SVDFOptions ? static_cast(builtin_options()) : nullptr; } const RNNOptions *builtin_options_as_RNNOptions() const { - return builtin_options_type() == BuiltinOptions_RNNOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_RNNOptions ? static_cast(builtin_options()) : nullptr; } - const FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() - const { - return builtin_options_type() == BuiltinOptions_FullyConnectedOptions - ? static_cast(builtin_options()) - : nullptr; + const FullyConnectedOptions *builtin_options_as_FullyConnectedOptions() const { + return builtin_options_type() == BuiltinOptions_FullyConnectedOptions ? static_cast(builtin_options()) : nullptr; } const SoftmaxOptions *builtin_options_as_SoftmaxOptions() const { - return builtin_options_type() == BuiltinOptions_SoftmaxOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SoftmaxOptions ? static_cast(builtin_options()) : nullptr; } const ConcatenationOptions *builtin_options_as_ConcatenationOptions() const { - return builtin_options_type() == BuiltinOptions_ConcatenationOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_ConcatenationOptions ? static_cast(builtin_options()) : nullptr; } const AddOptions *builtin_options_as_AddOptions() const { - return builtin_options_type() == BuiltinOptions_AddOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_AddOptions ? static_cast(builtin_options()) : nullptr; } const L2NormOptions *builtin_options_as_L2NormOptions() const { - return builtin_options_type() == BuiltinOptions_L2NormOptions - ? static_cast(builtin_options()) - : nullptr; - } - const LocalResponseNormalizationOptions * - builtin_options_as_LocalResponseNormalizationOptions() const { - return builtin_options_type() == - BuiltinOptions_LocalResponseNormalizationOptions - ? static_cast( - builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_L2NormOptions ? static_cast(builtin_options()) : nullptr; + } + const LocalResponseNormalizationOptions *builtin_options_as_LocalResponseNormalizationOptions() const { + return builtin_options_type() == BuiltinOptions_LocalResponseNormalizationOptions ? static_cast(builtin_options()) : nullptr; } const LSTMOptions *builtin_options_as_LSTMOptions() const { - return builtin_options_type() == BuiltinOptions_LSTMOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_LSTMOptions ? static_cast(builtin_options()) : nullptr; } - const ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() - const { - return builtin_options_type() == BuiltinOptions_ResizeBilinearOptions - ? static_cast(builtin_options()) - : nullptr; + const ResizeBilinearOptions *builtin_options_as_ResizeBilinearOptions() const { + return builtin_options_type() == BuiltinOptions_ResizeBilinearOptions ? static_cast(builtin_options()) : nullptr; } const CallOptions *builtin_options_as_CallOptions() const { - return builtin_options_type() == BuiltinOptions_CallOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_CallOptions ? static_cast(builtin_options()) : nullptr; } const ReshapeOptions *builtin_options_as_ReshapeOptions() const { - return builtin_options_type() == BuiltinOptions_ReshapeOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_ReshapeOptions ? static_cast(builtin_options()) : nullptr; } const SkipGramOptions *builtin_options_as_SkipGramOptions() const { - return builtin_options_type() == BuiltinOptions_SkipGramOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SkipGramOptions ? static_cast(builtin_options()) : nullptr; } const SpaceToDepthOptions *builtin_options_as_SpaceToDepthOptions() const { - return builtin_options_type() == BuiltinOptions_SpaceToDepthOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SpaceToDepthOptions ? static_cast(builtin_options()) : nullptr; } - const EmbeddingLookupSparseOptions * - builtin_options_as_EmbeddingLookupSparseOptions() const { - return builtin_options_type() == BuiltinOptions_EmbeddingLookupSparseOptions - ? static_cast( - builtin_options()) - : nullptr; + const EmbeddingLookupSparseOptions *builtin_options_as_EmbeddingLookupSparseOptions() const { + return builtin_options_type() == BuiltinOptions_EmbeddingLookupSparseOptions ? static_cast(builtin_options()) : nullptr; } const MulOptions *builtin_options_as_MulOptions() const { - return builtin_options_type() == BuiltinOptions_MulOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_MulOptions ? static_cast(builtin_options()) : nullptr; } const PadOptions *builtin_options_as_PadOptions() const { - return builtin_options_type() == BuiltinOptions_PadOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_PadOptions ? static_cast(builtin_options()) : nullptr; } const GatherOptions *builtin_options_as_GatherOptions() const { - return builtin_options_type() == BuiltinOptions_GatherOptions - ? static_cast(builtin_options()) - : nullptr; - } - const BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions() - const { - return builtin_options_type() == BuiltinOptions_BatchToSpaceNDOptions - ? static_cast(builtin_options()) - : nullptr; - } - const SpaceToBatchNDOptions *builtin_options_as_SpaceToBatchNDOptions() - const { - return builtin_options_type() == BuiltinOptions_SpaceToBatchNDOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_GatherOptions ? static_cast(builtin_options()) : nullptr; + } + const BatchToSpaceNDOptions *builtin_options_as_BatchToSpaceNDOptions() const { + return builtin_options_type() == BuiltinOptions_BatchToSpaceNDOptions ? static_cast(builtin_options()) : nullptr; + } + const SpaceToBatchNDOptions *builtin_options_as_SpaceToBatchNDOptions() const { + return builtin_options_type() == BuiltinOptions_SpaceToBatchNDOptions ? static_cast(builtin_options()) : nullptr; } const TransposeOptions *builtin_options_as_TransposeOptions() const { - return builtin_options_type() == BuiltinOptions_TransposeOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_TransposeOptions ? static_cast(builtin_options()) : nullptr; } const MeanOptions *builtin_options_as_MeanOptions() const { - return builtin_options_type() == BuiltinOptions_MeanOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_MeanOptions ? static_cast(builtin_options()) : nullptr; } const SubOptions *builtin_options_as_SubOptions() const { - return builtin_options_type() == BuiltinOptions_SubOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SubOptions ? static_cast(builtin_options()) : nullptr; } const DivOptions *builtin_options_as_DivOptions() const { - return builtin_options_type() == BuiltinOptions_DivOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_DivOptions ? static_cast(builtin_options()) : nullptr; } const SqueezeOptions *builtin_options_as_SqueezeOptions() const { - return builtin_options_type() == BuiltinOptions_SqueezeOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SqueezeOptions ? static_cast(builtin_options()) : nullptr; } const SequenceRNNOptions *builtin_options_as_SequenceRNNOptions() const { - return builtin_options_type() == BuiltinOptions_SequenceRNNOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_SequenceRNNOptions ? static_cast(builtin_options()) : nullptr; } const StridedSliceOptions *builtin_options_as_StridedSliceOptions() const { - return builtin_options_type() == BuiltinOptions_StridedSliceOptions - ? static_cast(builtin_options()) - : nullptr; + return builtin_options_type() == BuiltinOptions_StridedSliceOptions ? static_cast(builtin_options()) : nullptr; + } + const ExpOptions *builtin_options_as_ExpOptions() const { + return builtin_options_type() == BuiltinOptions_ExpOptions ? static_cast(builtin_options()) : nullptr; + } + const TopKV2Options *builtin_options_as_TopKV2Options() const { + return builtin_options_type() == BuiltinOptions_TopKV2Options ? static_cast(builtin_options()) : nullptr; + } + const SplitOptions *builtin_options_as_SplitOptions() const { + return builtin_options_type() == BuiltinOptions_SplitOptions ? static_cast(builtin_options()) : nullptr; + } + const LogSoftmaxOptions *builtin_options_as_LogSoftmaxOptions() const { + return builtin_options_type() == BuiltinOptions_LogSoftmaxOptions ? static_cast(builtin_options()) : nullptr; } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } CustomOptionsFormat custom_options_format() const { - return static_cast( - GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); + return static_cast(GetField(VT_CUSTOM_OPTIONS_FORMAT, 0)); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField(verifier, VT_OPCODE_INDEX) && - VerifyOffset(verifier, VT_INPUTS) && verifier.Verify(inputs()) && - VerifyOffset(verifier, VT_OUTPUTS) && verifier.Verify(outputs()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.Verify(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.Verify(outputs()) && VerifyField(verifier, VT_BUILTIN_OPTIONS_TYPE) && VerifyOffset(verifier, VT_BUILTIN_OPTIONS) && - VerifyBuiltinOptions(verifier, builtin_options(), - builtin_options_type()) && + VerifyBuiltinOptions(verifier, builtin_options(), builtin_options_type()) && VerifyOffset(verifier, VT_CUSTOM_OPTIONS) && verifier.Verify(custom_options()) && VerifyField(verifier, VT_CUSTOM_OPTIONS_FORMAT) && verifier.EndTable(); } - OperatorT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - OperatorT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + OperatorT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(OperatorT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; -template <> -inline const Conv2DOptions *Operator::builtin_options_as() - const { +template<> inline const Conv2DOptions *Operator::builtin_options_as() const { return builtin_options_as_Conv2DOptions(); } -template <> -inline const DepthwiseConv2DOptions * -Operator::builtin_options_as() const { +template<> inline const DepthwiseConv2DOptions *Operator::builtin_options_as() const { return builtin_options_as_DepthwiseConv2DOptions(); } -template <> -inline const ConcatEmbeddingsOptions * -Operator::builtin_options_as() const { +template<> inline const ConcatEmbeddingsOptions *Operator::builtin_options_as() const { return builtin_options_as_ConcatEmbeddingsOptions(); } -template <> -inline const LSHProjectionOptions * -Operator::builtin_options_as() const { +template<> inline const LSHProjectionOptions *Operator::builtin_options_as() const { return builtin_options_as_LSHProjectionOptions(); } -template <> -inline const Pool2DOptions *Operator::builtin_options_as() - const { +template<> inline const Pool2DOptions *Operator::builtin_options_as() const { return builtin_options_as_Pool2DOptions(); } -template <> -inline const SVDFOptions *Operator::builtin_options_as() const { +template<> inline const SVDFOptions *Operator::builtin_options_as() const { return builtin_options_as_SVDFOptions(); } -template <> -inline const RNNOptions *Operator::builtin_options_as() const { +template<> inline const RNNOptions *Operator::builtin_options_as() const { return builtin_options_as_RNNOptions(); } -template <> -inline const FullyConnectedOptions * -Operator::builtin_options_as() const { +template<> inline const FullyConnectedOptions *Operator::builtin_options_as() const { return builtin_options_as_FullyConnectedOptions(); } -template <> -inline const SoftmaxOptions *Operator::builtin_options_as() - const { +template<> inline const SoftmaxOptions *Operator::builtin_options_as() const { return builtin_options_as_SoftmaxOptions(); } -template <> -inline const ConcatenationOptions * -Operator::builtin_options_as() const { +template<> inline const ConcatenationOptions *Operator::builtin_options_as() const { return builtin_options_as_ConcatenationOptions(); } -template <> -inline const AddOptions *Operator::builtin_options_as() const { +template<> inline const AddOptions *Operator::builtin_options_as() const { return builtin_options_as_AddOptions(); } -template <> -inline const L2NormOptions *Operator::builtin_options_as() - const { +template<> inline const L2NormOptions *Operator::builtin_options_as() const { return builtin_options_as_L2NormOptions(); } -template <> -inline const LocalResponseNormalizationOptions * -Operator::builtin_options_as() const { +template<> inline const LocalResponseNormalizationOptions *Operator::builtin_options_as() const { return builtin_options_as_LocalResponseNormalizationOptions(); } -template <> -inline const LSTMOptions *Operator::builtin_options_as() const { +template<> inline const LSTMOptions *Operator::builtin_options_as() const { return builtin_options_as_LSTMOptions(); } -template <> -inline const ResizeBilinearOptions * -Operator::builtin_options_as() const { +template<> inline const ResizeBilinearOptions *Operator::builtin_options_as() const { return builtin_options_as_ResizeBilinearOptions(); } -template <> -inline const CallOptions *Operator::builtin_options_as() const { +template<> inline const CallOptions *Operator::builtin_options_as() const { return builtin_options_as_CallOptions(); } -template <> -inline const ReshapeOptions *Operator::builtin_options_as() - const { +template<> inline const ReshapeOptions *Operator::builtin_options_as() const { return builtin_options_as_ReshapeOptions(); } -template <> -inline const SkipGramOptions *Operator::builtin_options_as() - const { +template<> inline const SkipGramOptions *Operator::builtin_options_as() const { return builtin_options_as_SkipGramOptions(); } -template <> -inline const SpaceToDepthOptions * -Operator::builtin_options_as() const { +template<> inline const SpaceToDepthOptions *Operator::builtin_options_as() const { return builtin_options_as_SpaceToDepthOptions(); } -template <> -inline const EmbeddingLookupSparseOptions * -Operator::builtin_options_as() const { +template<> inline const EmbeddingLookupSparseOptions *Operator::builtin_options_as() const { return builtin_options_as_EmbeddingLookupSparseOptions(); } -template <> -inline const MulOptions *Operator::builtin_options_as() const { +template<> inline const MulOptions *Operator::builtin_options_as() const { return builtin_options_as_MulOptions(); } -template <> -inline const PadOptions *Operator::builtin_options_as() const { +template<> inline const PadOptions *Operator::builtin_options_as() const { return builtin_options_as_PadOptions(); } -template <> -inline const GatherOptions *Operator::builtin_options_as() - const { +template<> inline const GatherOptions *Operator::builtin_options_as() const { return builtin_options_as_GatherOptions(); } -template <> -inline const BatchToSpaceNDOptions * -Operator::builtin_options_as() const { +template<> inline const BatchToSpaceNDOptions *Operator::builtin_options_as() const { return builtin_options_as_BatchToSpaceNDOptions(); } -template <> -inline const SpaceToBatchNDOptions * -Operator::builtin_options_as() const { +template<> inline const SpaceToBatchNDOptions *Operator::builtin_options_as() const { return builtin_options_as_SpaceToBatchNDOptions(); } -template <> -inline const TransposeOptions *Operator::builtin_options_as() - const { +template<> inline const TransposeOptions *Operator::builtin_options_as() const { return builtin_options_as_TransposeOptions(); } -template <> -inline const MeanOptions *Operator::builtin_options_as() const { +template<> inline const MeanOptions *Operator::builtin_options_as() const { return builtin_options_as_MeanOptions(); } -template <> -inline const SubOptions *Operator::builtin_options_as() const { +template<> inline const SubOptions *Operator::builtin_options_as() const { return builtin_options_as_SubOptions(); } -template <> -inline const DivOptions *Operator::builtin_options_as() const { +template<> inline const DivOptions *Operator::builtin_options_as() const { return builtin_options_as_DivOptions(); } -template <> -inline const SqueezeOptions *Operator::builtin_options_as() - const { +template<> inline const SqueezeOptions *Operator::builtin_options_as() const { return builtin_options_as_SqueezeOptions(); } -template <> -inline const SequenceRNNOptions * -Operator::builtin_options_as() const { +template<> inline const SequenceRNNOptions *Operator::builtin_options_as() const { return builtin_options_as_SequenceRNNOptions(); } -template <> -inline const StridedSliceOptions * -Operator::builtin_options_as() const { +template<> inline const StridedSliceOptions *Operator::builtin_options_as() const { return builtin_options_as_StridedSliceOptions(); } +template<> inline const ExpOptions *Operator::builtin_options_as() const { + return builtin_options_as_ExpOptions(); +} + +template<> inline const TopKV2Options *Operator::builtin_options_as() const { + return builtin_options_as_TopKV2Options(); +} + +template<> inline const SplitOptions *Operator::builtin_options_as() const { + return builtin_options_as_SplitOptions(); +} + +template<> inline const LogSoftmaxOptions *Operator::builtin_options_as() const { + return builtin_options_as_LogSoftmaxOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -4084,21 +4043,19 @@ struct OperatorBuilder { fbb_.AddOffset(Operator::VT_OUTPUTS, outputs); } void add_builtin_options_type(BuiltinOptions builtin_options_type) { - fbb_.AddElement(Operator::VT_BUILTIN_OPTIONS_TYPE, - static_cast(builtin_options_type), 0); + fbb_.AddElement(Operator::VT_BUILTIN_OPTIONS_TYPE, static_cast(builtin_options_type), 0); } void add_builtin_options(flatbuffers::Offset builtin_options) { fbb_.AddOffset(Operator::VT_BUILTIN_OPTIONS, builtin_options); } - void add_custom_options( - flatbuffers::Offset> custom_options) { + void add_custom_options(flatbuffers::Offset> custom_options) { fbb_.AddOffset(Operator::VT_CUSTOM_OPTIONS, custom_options); } void add_custom_options_format(CustomOptionsFormat custom_options_format) { - fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, - static_cast(custom_options_format), 0); + fbb_.AddElement(Operator::VT_CUSTOM_OPTIONS_FORMAT, static_cast(custom_options_format), 0); } - explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + explicit OperatorBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } OperatorBuilder &operator=(const OperatorBuilder &); @@ -4110,14 +4067,14 @@ struct OperatorBuilder { }; inline flatbuffers::Offset CreateOperator( - flatbuffers::FlatBufferBuilder &_fbb, uint32_t opcode_index = 0, + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t opcode_index = 0, flatbuffers::Offset> inputs = 0, flatbuffers::Offset> outputs = 0, BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, flatbuffers::Offset> custom_options = 0, - CustomOptionsFormat custom_options_format = - CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { OperatorBuilder builder_(_fbb); builder_.add_custom_options(custom_options); builder_.add_builtin_options(builtin_options); @@ -4130,25 +4087,26 @@ inline flatbuffers::Offset CreateOperator( } inline flatbuffers::Offset CreateOperatorDirect( - flatbuffers::FlatBufferBuilder &_fbb, uint32_t opcode_index = 0, + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t opcode_index = 0, const std::vector *inputs = nullptr, const std::vector *outputs = nullptr, BuiltinOptions builtin_options_type = BuiltinOptions_NONE, flatbuffers::Offset builtin_options = 0, const std::vector *custom_options = nullptr, - CustomOptionsFormat custom_options_format = - CustomOptionsFormat_FLEXBUFFERS) { + CustomOptionsFormat custom_options_format = CustomOptionsFormat_FLEXBUFFERS) { return tflite::CreateOperator( - _fbb, opcode_index, inputs ? _fbb.CreateVector(*inputs) : 0, - outputs ? _fbb.CreateVector(*outputs) : 0, builtin_options_type, + _fbb, + opcode_index, + inputs ? _fbb.CreateVector(*inputs) : 0, + outputs ? _fbb.CreateVector(*outputs) : 0, + builtin_options_type, builtin_options, custom_options ? _fbb.CreateVector(*custom_options) : 0, custom_options_format); } -flatbuffers::Offset CreateOperator( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct SubGraphT : public flatbuffers::NativeTable { typedef SubGraph TableType; @@ -4157,7 +4115,8 @@ struct SubGraphT : public flatbuffers::NativeTable { std::vector outputs; std::vector> operators; std::string name; - SubGraphT() {} + SubGraphT() { + } }; struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -4170,8 +4129,7 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_NAME = 12 }; const flatbuffers::Vector> *tensors() const { - return GetPointer> *>( - VT_TENSORS); + return GetPointer> *>(VT_TENSORS); } const flatbuffers::Vector *inputs() const { return GetPointer *>(VT_INPUTS); @@ -4180,41 +4138,36 @@ struct SubGraph FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { return GetPointer *>(VT_OUTPUTS); } const flatbuffers::Vector> *operators() const { - return GetPointer< - const flatbuffers::Vector> *>( - VT_OPERATORS); + return GetPointer> *>(VT_OPERATORS); } const flatbuffers::String *name() const { return GetPointer(VT_NAME); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_TENSORS) && + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TENSORS) && verifier.Verify(tensors()) && verifier.VerifyVectorOfTables(tensors()) && - VerifyOffset(verifier, VT_INPUTS) && verifier.Verify(inputs()) && - VerifyOffset(verifier, VT_OUTPUTS) && verifier.Verify(outputs()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.Verify(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.Verify(outputs()) && VerifyOffset(verifier, VT_OPERATORS) && verifier.Verify(operators()) && verifier.VerifyVectorOfTables(operators()) && - VerifyOffset(verifier, VT_NAME) && verifier.Verify(name()) && + VerifyOffset(verifier, VT_NAME) && + verifier.Verify(name()) && verifier.EndTable(); } - SubGraphT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo( - SubGraphT *_o, - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + SubGraphT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(SubGraphT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct SubGraphBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; - void add_tensors( - flatbuffers::Offset>> - tensors) { + void add_tensors(flatbuffers::Offset>> tensors) { fbb_.AddOffset(SubGraph::VT_TENSORS, tensors); } void add_inputs(flatbuffers::Offset> inputs) { @@ -4223,15 +4176,14 @@ struct SubGraphBuilder { void add_outputs(flatbuffers::Offset> outputs) { fbb_.AddOffset(SubGraph::VT_OUTPUTS, outputs); } - void add_operators( - flatbuffers::Offset>> - operators) { + void add_operators(flatbuffers::Offset>> operators) { fbb_.AddOffset(SubGraph::VT_OPERATORS, operators); } void add_name(flatbuffers::Offset name) { fbb_.AddOffset(SubGraph::VT_NAME, name); } - explicit SubGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + explicit SubGraphBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } SubGraphBuilder &operator=(const SubGraphBuilder &); @@ -4244,12 +4196,10 @@ struct SubGraphBuilder { inline flatbuffers::Offset CreateSubGraph( flatbuffers::FlatBufferBuilder &_fbb, - flatbuffers::Offset>> - tensors = 0, + flatbuffers::Offset>> tensors = 0, flatbuffers::Offset> inputs = 0, flatbuffers::Offset> outputs = 0, - flatbuffers::Offset>> - operators = 0, + flatbuffers::Offset>> operators = 0, flatbuffers::Offset name = 0) { SubGraphBuilder builder_(_fbb); builder_.add_name(name); @@ -4272,38 +4222,36 @@ inline flatbuffers::Offset CreateSubGraphDirect( tensors ? _fbb.CreateVector>(*tensors) : 0, inputs ? _fbb.CreateVector(*inputs) : 0, outputs ? _fbb.CreateVector(*outputs) : 0, - operators ? _fbb.CreateVector>(*operators) - : 0, + operators ? _fbb.CreateVector>(*operators) : 0, name ? _fbb.CreateString(name) : 0); } -flatbuffers::Offset CreateSubGraph( - flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateSubGraph(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct BufferT : public flatbuffers::NativeTable { typedef Buffer TableType; std::vector data; - BufferT() {} + BufferT() { + } }; struct Buffer FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { typedef BufferT NativeTableType; - enum { VT_DATA = 4 }; + enum { + VT_DATA = 4 + }; const flatbuffers::Vector *data() const { return GetPointer *>(VT_DATA); } bool Verify(flatbuffers::Verifier &verifier) const { - return VerifyTableStart(verifier) && VerifyOffset(verifier, VT_DATA) && - verifier.Verify(data()) && verifier.EndTable(); + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.Verify(data()) && + verifier.EndTable(); } - BufferT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(BufferT *_o, const flatbuffers::resolver_function_t *_resolver = - nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + BufferT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BufferT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct BufferBuilder { @@ -4312,7 +4260,8 @@ struct BufferBuilder { void add_data(flatbuffers::Offset> data) { fbb_.AddOffset(Buffer::VT_DATA, data); } - explicit BufferBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + explicit BufferBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } BufferBuilder &operator=(const BufferBuilder &); @@ -4334,13 +4283,12 @@ inline flatbuffers::Offset CreateBuffer( inline flatbuffers::Offset CreateBufferDirect( flatbuffers::FlatBufferBuilder &_fbb, const std::vector *data = nullptr) { - return tflite::CreateBuffer(_fbb, - data ? _fbb.CreateVector(*data) : 0); + return tflite::CreateBuffer( + _fbb, + data ? _fbb.CreateVector(*data) : 0); } -flatbuffers::Offset CreateBuffer( - flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateBuffer(flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); struct ModelT : public flatbuffers::NativeTable { typedef Model TableType; @@ -4349,7 +4297,9 @@ struct ModelT : public flatbuffers::NativeTable { std::vector> subgraphs; std::string description; std::vector> buffers; - ModelT() : version(0) {} + ModelT() + : version(0) { + } }; struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { @@ -4361,24 +4311,20 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { VT_DESCRIPTION = 10, VT_BUFFERS = 12 }; - uint32_t version() const { return GetField(VT_VERSION, 0); } - const flatbuffers::Vector> *operator_codes() - const { - return GetPointer< - const flatbuffers::Vector> *>( - VT_OPERATOR_CODES); + uint32_t version() const { + return GetField(VT_VERSION, 0); + } + const flatbuffers::Vector> *operator_codes() const { + return GetPointer> *>(VT_OPERATOR_CODES); } const flatbuffers::Vector> *subgraphs() const { - return GetPointer< - const flatbuffers::Vector> *>( - VT_SUBGRAPHS); + return GetPointer> *>(VT_SUBGRAPHS); } const flatbuffers::String *description() const { return GetPointer(VT_DESCRIPTION); } const flatbuffers::Vector> *buffers() const { - return GetPointer> *>( - VT_BUFFERS); + return GetPointer> *>(VT_BUFFERS); } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && @@ -4391,16 +4337,14 @@ struct Model FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { verifier.VerifyVectorOfTables(subgraphs()) && VerifyOffset(verifier, VT_DESCRIPTION) && verifier.Verify(description()) && - VerifyOffset(verifier, VT_BUFFERS) && verifier.Verify(buffers()) && - verifier.VerifyVectorOfTables(buffers()) && verifier.EndTable(); + VerifyOffset(verifier, VT_BUFFERS) && + verifier.Verify(buffers()) && + verifier.VerifyVectorOfTables(buffers()) && + verifier.EndTable(); } - ModelT *UnPack( - const flatbuffers::resolver_function_t *_resolver = nullptr) const; - void UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver = - nullptr) const; - static flatbuffers::Offset Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); + ModelT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); }; struct ModelBuilder { @@ -4409,26 +4353,20 @@ struct ModelBuilder { void add_version(uint32_t version) { fbb_.AddElement(Model::VT_VERSION, version, 0); } - void add_operator_codes( - flatbuffers::Offset< - flatbuffers::Vector>> - operator_codes) { + void add_operator_codes(flatbuffers::Offset>> operator_codes) { fbb_.AddOffset(Model::VT_OPERATOR_CODES, operator_codes); } - void add_subgraphs( - flatbuffers::Offset>> - subgraphs) { + void add_subgraphs(flatbuffers::Offset>> subgraphs) { fbb_.AddOffset(Model::VT_SUBGRAPHS, subgraphs); } void add_description(flatbuffers::Offset description) { fbb_.AddOffset(Model::VT_DESCRIPTION, description); } - void add_buffers( - flatbuffers::Offset>> - buffers) { + void add_buffers(flatbuffers::Offset>> buffers) { fbb_.AddOffset(Model::VT_BUFFERS, buffers); } - explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { + explicit ModelBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { start_ = fbb_.StartTable(); } ModelBuilder &operator=(const ModelBuilder &); @@ -4440,14 +4378,12 @@ struct ModelBuilder { }; inline flatbuffers::Offset CreateModel( - flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, - flatbuffers::Offset>> - operator_codes = 0, - flatbuffers::Offset>> - subgraphs = 0, + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t version = 0, + flatbuffers::Offset>> operator_codes = 0, + flatbuffers::Offset>> subgraphs = 0, flatbuffers::Offset description = 0, - flatbuffers::Offset>> - buffers = 0) { + flatbuffers::Offset>> buffers = 0) { ModelBuilder builder_(_fbb); builder_.add_buffers(buffers); builder_.add_description(description); @@ -4458,2010 +4394,1300 @@ inline flatbuffers::Offset CreateModel( } inline flatbuffers::Offset CreateModelDirect( - flatbuffers::FlatBufferBuilder &_fbb, uint32_t version = 0, - const std::vector> *operator_codes = - nullptr, + flatbuffers::FlatBufferBuilder &_fbb, + uint32_t version = 0, + const std::vector> *operator_codes = nullptr, const std::vector> *subgraphs = nullptr, const char *description = nullptr, const std::vector> *buffers = nullptr) { return tflite::CreateModel( - _fbb, version, - operator_codes ? _fbb.CreateVector>( - *operator_codes) - : 0, - subgraphs ? _fbb.CreateVector>(*subgraphs) - : 0, + _fbb, + version, + operator_codes ? _fbb.CreateVector>(*operator_codes) : 0, + subgraphs ? _fbb.CreateVector>(*subgraphs) : 0, description ? _fbb.CreateString(description) : 0, buffers ? _fbb.CreateVector>(*buffers) : 0); } -flatbuffers::Offset CreateModel( - flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, - const flatbuffers::rehasher_function_t *_rehasher = nullptr); +flatbuffers::Offset CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); -inline QuantizationParametersT *QuantizationParameters::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline QuantizationParametersT *QuantizationParameters::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new QuantizationParametersT(); UnPackTo(_o, _resolver); return _o; } -inline void QuantizationParameters::UnPackTo( - QuantizationParametersT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void QuantizationParameters::UnPackTo(QuantizationParametersT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = min(); - if (_e) { - _o->min.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->min[_i] = _e->Get(_i); - } - } - }; - { - auto _e = max(); - if (_e) { - _o->max.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->max[_i] = _e->Get(_i); - } - } - }; - { - auto _e = scale(); - if (_e) { - _o->scale.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->scale[_i] = _e->Get(_i); - } - } - }; - { - auto _e = zero_point(); - if (_e) { - _o->zero_point.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->zero_point[_i] = _e->Get(_i); - } - } - }; + { auto _e = min(); if (_e) { _o->min.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->min[_i] = _e->Get(_i); } } }; + { auto _e = max(); if (_e) { _o->max.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->max[_i] = _e->Get(_i); } } }; + { auto _e = scale(); if (_e) { _o->scale.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->scale[_i] = _e->Get(_i); } } }; + { auto _e = zero_point(); if (_e) { _o->zero_point.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->zero_point[_i] = _e->Get(_i); } } }; } -inline flatbuffers::Offset QuantizationParameters::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset QuantizationParameters::Pack(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateQuantizationParameters(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateQuantizationParameters( - flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateQuantizationParameters(flatbuffers::FlatBufferBuilder &_fbb, const QuantizationParametersT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const QuantizationParametersT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const QuantizationParametersT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _min = _o->min.size() ? _fbb.CreateVector(_o->min) : 0; auto _max = _o->max.size() ? _fbb.CreateVector(_o->max) : 0; auto _scale = _o->scale.size() ? _fbb.CreateVector(_o->scale) : 0; - auto _zero_point = - _o->zero_point.size() ? _fbb.CreateVector(_o->zero_point) : 0; - return tflite::CreateQuantizationParameters(_fbb, _min, _max, _scale, - _zero_point); + auto _zero_point = _o->zero_point.size() ? _fbb.CreateVector(_o->zero_point) : 0; + return tflite::CreateQuantizationParameters( + _fbb, + _min, + _max, + _scale, + _zero_point); } -inline TensorT *Tensor::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline TensorT *Tensor::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new TensorT(); UnPackTo(_o, _resolver); return _o; } -inline void Tensor::UnPackTo( - TensorT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void Tensor::UnPackTo(TensorT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = shape(); - if (_e) { - _o->shape.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->shape[_i] = _e->Get(_i); - } - } - }; - { - auto _e = type(); - _o->type = _e; - }; - { - auto _e = buffer(); - _o->buffer = _e; - }; - { - auto _e = name(); - if (_e) _o->name = _e->str(); - }; - { - auto _e = quantization(); - if (_e) - _o->quantization = - std::unique_ptr(_e->UnPack(_resolver)); - }; + { auto _e = shape(); if (_e) { _o->shape.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->shape[_i] = _e->Get(_i); } } }; + { auto _e = type(); _o->type = _e; }; + { auto _e = buffer(); _o->buffer = _e; }; + { auto _e = name(); if (_e) _o->name = _e->str(); }; + { auto _e = quantization(); if (_e) _o->quantization = std::unique_ptr(_e->UnPack(_resolver)); }; } -inline flatbuffers::Offset Tensor::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Tensor::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TensorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateTensor(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateTensor( - flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateTensor(flatbuffers::FlatBufferBuilder &_fbb, const TensorT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const TensorT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TensorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _shape = _o->shape.size() ? _fbb.CreateVector(_o->shape) : 0; auto _type = _o->type; auto _buffer = _o->buffer; auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); - auto _quantization = _o->quantization - ? CreateQuantizationParameters( - _fbb, _o->quantization.get(), _rehasher) - : 0; - return tflite::CreateTensor(_fbb, _shape, _type, _buffer, _name, - _quantization); + auto _quantization = _o->quantization ? CreateQuantizationParameters(_fbb, _o->quantization.get(), _rehasher) : 0; + return tflite::CreateTensor( + _fbb, + _shape, + _type, + _buffer, + _name, + _quantization); } -inline Conv2DOptionsT *Conv2DOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline Conv2DOptionsT *Conv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new Conv2DOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void Conv2DOptions::UnPackTo( - Conv2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void Conv2DOptions::UnPackTo(Conv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = padding(); - _o->padding = _e; - }; - { - auto _e = stride_w(); - _o->stride_w = _e; - }; - { - auto _e = stride_h(); - _o->stride_h = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = padding(); _o->padding = _e; }; + { auto _e = stride_w(); _o->stride_w = _e; }; + { auto _e = stride_h(); _o->stride_h = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset Conv2DOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Conv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateConv2DOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Conv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const Conv2DOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Conv2DOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _padding = _o->padding; auto _stride_w = _o->stride_w; auto _stride_h = _o->stride_h; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateConv2DOptions(_fbb, _padding, _stride_w, _stride_h, - _fused_activation_function); + return tflite::CreateConv2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _fused_activation_function); } -inline Pool2DOptionsT *Pool2DOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline Pool2DOptionsT *Pool2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new Pool2DOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void Pool2DOptions::UnPackTo( - Pool2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void Pool2DOptions::UnPackTo(Pool2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = padding(); - _o->padding = _e; - }; - { - auto _e = stride_w(); - _o->stride_w = _e; - }; - { - auto _e = stride_h(); - _o->stride_h = _e; - }; - { - auto _e = filter_width(); - _o->filter_width = _e; - }; - { - auto _e = filter_height(); - _o->filter_height = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = padding(); _o->padding = _e; }; + { auto _e = stride_w(); _o->stride_w = _e; }; + { auto _e = stride_h(); _o->stride_h = _e; }; + { auto _e = filter_width(); _o->filter_width = _e; }; + { auto _e = filter_height(); _o->filter_height = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset Pool2DOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Pool2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreatePool2DOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreatePool2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreatePool2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const Pool2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const Pool2DOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const Pool2DOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _padding = _o->padding; auto _stride_w = _o->stride_w; auto _stride_h = _o->stride_h; auto _filter_width = _o->filter_width; auto _filter_height = _o->filter_height; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreatePool2DOptions(_fbb, _padding, _stride_w, _stride_h, - _filter_width, _filter_height, - _fused_activation_function); + return tflite::CreatePool2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _filter_width, + _filter_height, + _fused_activation_function); } -inline DepthwiseConv2DOptionsT *DepthwiseConv2DOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline DepthwiseConv2DOptionsT *DepthwiseConv2DOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new DepthwiseConv2DOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void DepthwiseConv2DOptions::UnPackTo( - DepthwiseConv2DOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void DepthwiseConv2DOptions::UnPackTo(DepthwiseConv2DOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = padding(); - _o->padding = _e; - }; - { - auto _e = stride_w(); - _o->stride_w = _e; - }; - { - auto _e = stride_h(); - _o->stride_h = _e; - }; - { - auto _e = depth_multiplier(); - _o->depth_multiplier = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = padding(); _o->padding = _e; }; + { auto _e = stride_w(); _o->stride_w = _e; }; + { auto _e = stride_h(); _o->stride_h = _e; }; + { auto _e = depth_multiplier(); _o->depth_multiplier = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset DepthwiseConv2DOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset DepthwiseConv2DOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateDepthwiseConv2DOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateDepthwiseConv2DOptions( - flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateDepthwiseConv2DOptions(flatbuffers::FlatBufferBuilder &_fbb, const DepthwiseConv2DOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const DepthwiseConv2DOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DepthwiseConv2DOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _padding = _o->padding; auto _stride_w = _o->stride_w; auto _stride_h = _o->stride_h; auto _depth_multiplier = _o->depth_multiplier; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateDepthwiseConv2DOptions(_fbb, _padding, _stride_w, - _stride_h, _depth_multiplier, - _fused_activation_function); + return tflite::CreateDepthwiseConv2DOptions( + _fbb, + _padding, + _stride_w, + _stride_h, + _depth_multiplier, + _fused_activation_function); } -inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ConcatEmbeddingsOptionsT *ConcatEmbeddingsOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ConcatEmbeddingsOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void ConcatEmbeddingsOptions::UnPackTo( - ConcatEmbeddingsOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void ConcatEmbeddingsOptions::UnPackTo(ConcatEmbeddingsOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = num_channels(); - _o->num_channels = _e; - }; - { - auto _e = num_columns_per_channel(); - if (_e) { - _o->num_columns_per_channel.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->num_columns_per_channel[_i] = _e->Get(_i); - } - } - }; - { - auto _e = embedding_dim_per_channel(); - if (_e) { - _o->embedding_dim_per_channel.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->embedding_dim_per_channel[_i] = _e->Get(_i); - } - } - }; + { auto _e = num_channels(); _o->num_channels = _e; }; + { auto _e = num_columns_per_channel(); if (_e) { _o->num_columns_per_channel.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->num_columns_per_channel[_i] = _e->Get(_i); } } }; + { auto _e = embedding_dim_per_channel(); if (_e) { _o->embedding_dim_per_channel.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->embedding_dim_per_channel[_i] = _e->Get(_i); } } }; } -inline flatbuffers::Offset -ConcatEmbeddingsOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset ConcatEmbeddingsOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateConcatEmbeddingsOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset -CreateConcatEmbeddingsOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateConcatEmbeddingsOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatEmbeddingsOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const ConcatEmbeddingsOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ConcatEmbeddingsOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _num_channels = _o->num_channels; - auto _num_columns_per_channel = - _o->num_columns_per_channel.size() - ? _fbb.CreateVector(_o->num_columns_per_channel) - : 0; - auto _embedding_dim_per_channel = - _o->embedding_dim_per_channel.size() - ? _fbb.CreateVector(_o->embedding_dim_per_channel) - : 0; - return tflite::CreateConcatEmbeddingsOptions(_fbb, _num_channels, - _num_columns_per_channel, - _embedding_dim_per_channel); -} - -inline LSHProjectionOptionsT *LSHProjectionOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { + auto _num_columns_per_channel = _o->num_columns_per_channel.size() ? _fbb.CreateVector(_o->num_columns_per_channel) : 0; + auto _embedding_dim_per_channel = _o->embedding_dim_per_channel.size() ? _fbb.CreateVector(_o->embedding_dim_per_channel) : 0; + return tflite::CreateConcatEmbeddingsOptions( + _fbb, + _num_channels, + _num_columns_per_channel, + _embedding_dim_per_channel); +} + +inline LSHProjectionOptionsT *LSHProjectionOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new LSHProjectionOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void LSHProjectionOptions::UnPackTo( - LSHProjectionOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void LSHProjectionOptions::UnPackTo(LSHProjectionOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = type(); - _o->type = _e; - }; + { auto _e = type(); _o->type = _e; }; } -inline flatbuffers::Offset LSHProjectionOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset LSHProjectionOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateLSHProjectionOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateLSHProjectionOptions( - flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateLSHProjectionOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSHProjectionOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const LSHProjectionOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LSHProjectionOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _type = _o->type; - return tflite::CreateLSHProjectionOptions(_fbb, _type); + return tflite::CreateLSHProjectionOptions( + _fbb, + _type); } -inline SVDFOptionsT *SVDFOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SVDFOptionsT *SVDFOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SVDFOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SVDFOptions::UnPackTo( - SVDFOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void SVDFOptions::UnPackTo(SVDFOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = rank(); - _o->rank = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = rank(); _o->rank = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset SVDFOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SVDFOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSVDFOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSVDFOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSVDFOptions(flatbuffers::FlatBufferBuilder &_fbb, const SVDFOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SVDFOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SVDFOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _rank = _o->rank; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateSVDFOptions(_fbb, _rank, _fused_activation_function); + return tflite::CreateSVDFOptions( + _fbb, + _rank, + _fused_activation_function); } -inline RNNOptionsT *RNNOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline RNNOptionsT *RNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new RNNOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void RNNOptions::UnPackTo( - RNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void RNNOptions::UnPackTo(RNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset RNNOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset RNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateRNNOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const RNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const RNNOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const RNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateRNNOptions(_fbb, _fused_activation_function); + return tflite::CreateRNNOptions( + _fbb, + _fused_activation_function); } -inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SequenceRNNOptionsT *SequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SequenceRNNOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SequenceRNNOptions::UnPackTo( - SequenceRNNOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SequenceRNNOptions::UnPackTo(SequenceRNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = time_major(); - _o->time_major = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = time_major(); _o->time_major = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset SequenceRNNOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSequenceRNNOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSequenceRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSequenceRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const SequenceRNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SequenceRNNOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _time_major = _o->time_major; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateSequenceRNNOptions(_fbb, _time_major, - _fused_activation_function); + return tflite::CreateSequenceRNNOptions( + _fbb, + _time_major, + _fused_activation_function); } -inline BidirectionalSequenceRNNOptionsT * -BidirectionalSequenceRNNOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline BidirectionalSequenceRNNOptionsT *BidirectionalSequenceRNNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new BidirectionalSequenceRNNOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void BidirectionalSequenceRNNOptions::UnPackTo( - BidirectionalSequenceRNNOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = time_major(); - _o->time_major = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = time_major(); _o->time_major = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset -BidirectionalSequenceRNNOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const BidirectionalSequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateBidirectionalSequenceRNNOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset -CreateBidirectionalSequenceRNNOptions( - flatbuffers::FlatBufferBuilder &_fbb, - const BidirectionalSequenceRNNOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateBidirectionalSequenceRNNOptions(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const BidirectionalSequenceRNNOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BidirectionalSequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _time_major = _o->time_major; auto _fused_activation_function = _o->fused_activation_function; return tflite::CreateBidirectionalSequenceRNNOptions( - _fbb, _time_major, _fused_activation_function); + _fbb, + _time_major, + _fused_activation_function); } -inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new FullyConnectedOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void FullyConnectedOptions::UnPackTo( - FullyConnectedOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void FullyConnectedOptions::UnPackTo(FullyConnectedOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset FullyConnectedOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset FullyConnectedOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateFullyConnectedOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateFullyConnectedOptions( - flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateFullyConnectedOptions(flatbuffers::FlatBufferBuilder &_fbb, const FullyConnectedOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const FullyConnectedOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const FullyConnectedOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateFullyConnectedOptions(_fbb, _fused_activation_function); + return tflite::CreateFullyConnectedOptions( + _fbb, + _fused_activation_function); } -inline SoftmaxOptionsT *SoftmaxOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SoftmaxOptionsT *SoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SoftmaxOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SoftmaxOptions::UnPackTo( - SoftmaxOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SoftmaxOptions::UnPackTo(SoftmaxOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = beta(); - _o->beta = _e; - }; + { auto _e = beta(); _o->beta = _e; }; } -inline flatbuffers::Offset SoftmaxOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SoftmaxOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSoftmaxOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSoftmaxOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSoftmaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const SoftmaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SoftmaxOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SoftmaxOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _beta = _o->beta; - return tflite::CreateSoftmaxOptions(_fbb, _beta); + return tflite::CreateSoftmaxOptions( + _fbb, + _beta); } -inline ConcatenationOptionsT *ConcatenationOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ConcatenationOptionsT *ConcatenationOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ConcatenationOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void ConcatenationOptions::UnPackTo( - ConcatenationOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void ConcatenationOptions::UnPackTo(ConcatenationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = axis(); - _o->axis = _e; - }; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = axis(); _o->axis = _e; }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset ConcatenationOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset ConcatenationOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateConcatenationOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateConcatenationOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateConcatenationOptions(flatbuffers::FlatBufferBuilder &_fbb, const ConcatenationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const ConcatenationOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ConcatenationOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _axis = _o->axis; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateConcatenationOptions(_fbb, _axis, - _fused_activation_function); + return tflite::CreateConcatenationOptions( + _fbb, + _axis, + _fused_activation_function); } -inline AddOptionsT *AddOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline AddOptionsT *AddOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new AddOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void AddOptions::UnPackTo( - AddOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void AddOptions::UnPackTo(AddOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset AddOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset AddOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateAddOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateAddOptions( - flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateAddOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const AddOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const AddOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateAddOptions(_fbb, _fused_activation_function); + return tflite::CreateAddOptions( + _fbb, + _fused_activation_function); } -inline MulOptionsT *MulOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline MulOptionsT *MulOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new MulOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void MulOptions::UnPackTo( - MulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void MulOptions::UnPackTo(MulOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset MulOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset MulOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateMulOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateMulOptions( - flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateMulOptions(flatbuffers::FlatBufferBuilder &_fbb, const MulOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const MulOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MulOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateMulOptions(_fbb, _fused_activation_function); + return tflite::CreateMulOptions( + _fbb, + _fused_activation_function); } -inline L2NormOptionsT *L2NormOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline L2NormOptionsT *L2NormOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new L2NormOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void L2NormOptions::UnPackTo( - L2NormOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void L2NormOptions::UnPackTo(L2NormOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset L2NormOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset L2NormOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateL2NormOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateL2NormOptions( - flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateL2NormOptions(flatbuffers::FlatBufferBuilder &_fbb, const L2NormOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const L2NormOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const L2NormOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateL2NormOptions(_fbb, _fused_activation_function); + return tflite::CreateL2NormOptions( + _fbb, + _fused_activation_function); } -inline LocalResponseNormalizationOptionsT * -LocalResponseNormalizationOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline LocalResponseNormalizationOptionsT *LocalResponseNormalizationOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new LocalResponseNormalizationOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void LocalResponseNormalizationOptions::UnPackTo( - LocalResponseNormalizationOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void LocalResponseNormalizationOptions::UnPackTo(LocalResponseNormalizationOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = radius(); - _o->radius = _e; - }; - { - auto _e = bias(); - _o->bias = _e; - }; - { - auto _e = alpha(); - _o->alpha = _e; - }; - { - auto _e = beta(); - _o->beta = _e; - }; + { auto _e = radius(); _o->radius = _e; }; + { auto _e = bias(); _o->bias = _e; }; + { auto _e = alpha(); _o->alpha = _e; }; + { auto _e = beta(); _o->beta = _e; }; } -inline flatbuffers::Offset -LocalResponseNormalizationOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const LocalResponseNormalizationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset LocalResponseNormalizationOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateLocalResponseNormalizationOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset -CreateLocalResponseNormalizationOptions( - flatbuffers::FlatBufferBuilder &_fbb, - const LocalResponseNormalizationOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateLocalResponseNormalizationOptions(flatbuffers::FlatBufferBuilder &_fbb, const LocalResponseNormalizationOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const LocalResponseNormalizationOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LocalResponseNormalizationOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _radius = _o->radius; auto _bias = _o->bias; auto _alpha = _o->alpha; auto _beta = _o->beta; - return tflite::CreateLocalResponseNormalizationOptions(_fbb, _radius, _bias, - _alpha, _beta); + return tflite::CreateLocalResponseNormalizationOptions( + _fbb, + _radius, + _bias, + _alpha, + _beta); } -inline LSTMOptionsT *LSTMOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline LSTMOptionsT *LSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new LSTMOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void LSTMOptions::UnPackTo( - LSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void LSTMOptions::UnPackTo(LSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; - { - auto _e = cell_clip(); - _o->cell_clip = _e; - }; - { - auto _e = proj_clip(); - _o->proj_clip = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = cell_clip(); _o->cell_clip = _e; }; + { auto _e = proj_clip(); _o->proj_clip = _e; }; } -inline flatbuffers::Offset LSTMOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset LSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateLSTMOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateLSTMOptions( - flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const LSTMOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LSTMOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; auto _cell_clip = _o->cell_clip; auto _proj_clip = _o->proj_clip; - return tflite::CreateLSTMOptions(_fbb, _fused_activation_function, _cell_clip, - _proj_clip); + return tflite::CreateLSTMOptions( + _fbb, + _fused_activation_function, + _cell_clip, + _proj_clip); } -inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ResizeBilinearOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void ResizeBilinearOptions::UnPackTo( - ResizeBilinearOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void ResizeBilinearOptions::UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; + { auto _e = align_corners(); _o->align_corners = _e; }; } -inline flatbuffers::Offset ResizeBilinearOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset ResizeBilinearOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateResizeBilinearOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateResizeBilinearOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const ResizeBilinearOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - return tflite::CreateResizeBilinearOptions(_fbb); + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ResizeBilinearOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _align_corners = _o->align_corners; + return tflite::CreateResizeBilinearOptions( + _fbb, + _align_corners); } -inline CallOptionsT *CallOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline CallOptionsT *CallOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new CallOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void CallOptions::UnPackTo( - CallOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void CallOptions::UnPackTo(CallOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = subgraph(); - _o->subgraph = _e; - }; + { auto _e = subgraph(); _o->subgraph = _e; }; } -inline flatbuffers::Offset CallOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CallOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateCallOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateCallOptions( - flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateCallOptions(flatbuffers::FlatBufferBuilder &_fbb, const CallOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const CallOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const CallOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _subgraph = _o->subgraph; - return tflite::CreateCallOptions(_fbb, _subgraph); + return tflite::CreateCallOptions( + _fbb, + _subgraph); } -inline PadOptionsT *PadOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline PadOptionsT *PadOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new PadOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void PadOptions::UnPackTo( - PadOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void PadOptions::UnPackTo(PadOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; } -inline flatbuffers::Offset PadOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset PadOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreatePadOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreatePadOptions( - flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreatePadOptions(flatbuffers::FlatBufferBuilder &_fbb, const PadOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const PadOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - return tflite::CreatePadOptions(_fbb); + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const PadOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreatePadOptions( + _fbb); } -inline ReshapeOptionsT *ReshapeOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ReshapeOptionsT *ReshapeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ReshapeOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void ReshapeOptions::UnPackTo( - ReshapeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void ReshapeOptions::UnPackTo(ReshapeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = new_shape(); - if (_e) { - _o->new_shape.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->new_shape[_i] = _e->Get(_i); - } - } - }; + { auto _e = new_shape(); if (_e) { _o->new_shape.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->new_shape[_i] = _e->Get(_i); } } }; } -inline flatbuffers::Offset ReshapeOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset ReshapeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateReshapeOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateReshapeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateReshapeOptions(flatbuffers::FlatBufferBuilder &_fbb, const ReshapeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const ReshapeOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ReshapeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _new_shape = _o->new_shape.size() ? _fbb.CreateVector(_o->new_shape) : 0; - return tflite::CreateReshapeOptions(_fbb, _new_shape); + return tflite::CreateReshapeOptions( + _fbb, + _new_shape); } -inline SpaceToBatchNDOptionsT *SpaceToBatchNDOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SpaceToBatchNDOptionsT *SpaceToBatchNDOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SpaceToBatchNDOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SpaceToBatchNDOptions::UnPackTo( - SpaceToBatchNDOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SpaceToBatchNDOptions::UnPackTo(SpaceToBatchNDOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; } -inline flatbuffers::Offset SpaceToBatchNDOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SpaceToBatchNDOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSpaceToBatchNDOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSpaceToBatchNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSpaceToBatchNDOptions(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToBatchNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SpaceToBatchNDOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - return tflite::CreateSpaceToBatchNDOptions(_fbb); + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SpaceToBatchNDOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateSpaceToBatchNDOptions( + _fbb); } -inline BatchToSpaceNDOptionsT *BatchToSpaceNDOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline BatchToSpaceNDOptionsT *BatchToSpaceNDOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new BatchToSpaceNDOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void BatchToSpaceNDOptions::UnPackTo( - BatchToSpaceNDOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void BatchToSpaceNDOptions::UnPackTo(BatchToSpaceNDOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; } -inline flatbuffers::Offset BatchToSpaceNDOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset BatchToSpaceNDOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateBatchToSpaceNDOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateBatchToSpaceNDOptions( - flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateBatchToSpaceNDOptions(flatbuffers::FlatBufferBuilder &_fbb, const BatchToSpaceNDOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const BatchToSpaceNDOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - return tflite::CreateBatchToSpaceNDOptions(_fbb); + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BatchToSpaceNDOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateBatchToSpaceNDOptions( + _fbb); } -inline SkipGramOptionsT *SkipGramOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SkipGramOptionsT *SkipGramOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SkipGramOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SkipGramOptions::UnPackTo( - SkipGramOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SkipGramOptions::UnPackTo(SkipGramOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = ngram_size(); - _o->ngram_size = _e; - }; - { - auto _e = max_skip_size(); - _o->max_skip_size = _e; - }; - { - auto _e = include_all_ngrams(); - _o->include_all_ngrams = _e; - }; + { auto _e = ngram_size(); _o->ngram_size = _e; }; + { auto _e = max_skip_size(); _o->max_skip_size = _e; }; + { auto _e = include_all_ngrams(); _o->include_all_ngrams = _e; }; } -inline flatbuffers::Offset SkipGramOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SkipGramOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSkipGramOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSkipGramOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSkipGramOptions(flatbuffers::FlatBufferBuilder &_fbb, const SkipGramOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SkipGramOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SkipGramOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _ngram_size = _o->ngram_size; auto _max_skip_size = _o->max_skip_size; auto _include_all_ngrams = _o->include_all_ngrams; - return tflite::CreateSkipGramOptions(_fbb, _ngram_size, _max_skip_size, - _include_all_ngrams); + return tflite::CreateSkipGramOptions( + _fbb, + _ngram_size, + _max_skip_size, + _include_all_ngrams); } -inline SpaceToDepthOptionsT *SpaceToDepthOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SpaceToDepthOptionsT *SpaceToDepthOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SpaceToDepthOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SpaceToDepthOptions::UnPackTo( - SpaceToDepthOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SpaceToDepthOptions::UnPackTo(SpaceToDepthOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = block_size(); - _o->block_size = _e; - }; + { auto _e = block_size(); _o->block_size = _e; }; } -inline flatbuffers::Offset SpaceToDepthOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SpaceToDepthOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSpaceToDepthOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSpaceToDepthOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSpaceToDepthOptions(flatbuffers::FlatBufferBuilder &_fbb, const SpaceToDepthOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SpaceToDepthOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SpaceToDepthOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _block_size = _o->block_size; - return tflite::CreateSpaceToDepthOptions(_fbb, _block_size); + return tflite::CreateSpaceToDepthOptions( + _fbb, + _block_size); } -inline SubOptionsT *SubOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SubOptionsT *SubOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SubOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SubOptions::UnPackTo( - SubOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void SubOptions::UnPackTo(SubOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset SubOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SubOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSubOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSubOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSubOptions(flatbuffers::FlatBufferBuilder &_fbb, const SubOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SubOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SubOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateSubOptions(_fbb, _fused_activation_function); + return tflite::CreateSubOptions( + _fbb, + _fused_activation_function); } -inline DivOptionsT *DivOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline DivOptionsT *DivOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new DivOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void DivOptions::UnPackTo( - DivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void DivOptions::UnPackTo(DivOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = fused_activation_function(); - _o->fused_activation_function = _e; - }; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; } -inline flatbuffers::Offset DivOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset DivOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateDivOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateDivOptions( - flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateDivOptions(flatbuffers::FlatBufferBuilder &_fbb, const DivOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const DivOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DivOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _fused_activation_function = _o->fused_activation_function; - return tflite::CreateDivOptions(_fbb, _fused_activation_function); + return tflite::CreateDivOptions( + _fbb, + _fused_activation_function); +} + +inline TopKV2OptionsT *TopKV2Options::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new TopKV2OptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void TopKV2Options::UnPackTo(TopKV2OptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset TopKV2Options::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateTopKV2Options(_fbb, _o, _rehasher); } -inline EmbeddingLookupSparseOptionsT *EmbeddingLookupSparseOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline flatbuffers::Offset CreateTopKV2Options(flatbuffers::FlatBufferBuilder &_fbb, const TopKV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TopKV2OptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateTopKV2Options( + _fbb); +} + +inline EmbeddingLookupSparseOptionsT *EmbeddingLookupSparseOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new EmbeddingLookupSparseOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void EmbeddingLookupSparseOptions::UnPackTo( - EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void EmbeddingLookupSparseOptions::UnPackTo(EmbeddingLookupSparseOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = combiner(); - _o->combiner = _e; - }; + { auto _e = combiner(); _o->combiner = _e; }; } -inline flatbuffers::Offset -EmbeddingLookupSparseOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset EmbeddingLookupSparseOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateEmbeddingLookupSparseOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset -CreateEmbeddingLookupSparseOptions( - flatbuffers::FlatBufferBuilder &_fbb, - const EmbeddingLookupSparseOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateEmbeddingLookupSparseOptions(flatbuffers::FlatBufferBuilder &_fbb, const EmbeddingLookupSparseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const EmbeddingLookupSparseOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const EmbeddingLookupSparseOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _combiner = _o->combiner; - return tflite::CreateEmbeddingLookupSparseOptions(_fbb, _combiner); + return tflite::CreateEmbeddingLookupSparseOptions( + _fbb, + _combiner); } -inline GatherOptionsT *GatherOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline GatherOptionsT *GatherOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new GatherOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void GatherOptions::UnPackTo( - GatherOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void GatherOptions::UnPackTo(GatherOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = axis(); - _o->axis = _e; - }; + { auto _e = axis(); _o->axis = _e; }; } -inline flatbuffers::Offset GatherOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset GatherOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateGatherOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateGatherOptions( - flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateGatherOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const GatherOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GatherOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _axis = _o->axis; - return tflite::CreateGatherOptions(_fbb, _axis); + return tflite::CreateGatherOptions( + _fbb, + _axis); } -inline TransposeOptionsT *TransposeOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline TransposeOptionsT *TransposeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new TransposeOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void TransposeOptions::UnPackTo( - TransposeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void TransposeOptions::UnPackTo(TransposeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; } -inline flatbuffers::Offset TransposeOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset TransposeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateTransposeOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateTransposeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateTransposeOptions(flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const TransposeOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - return tflite::CreateTransposeOptions(_fbb); + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const TransposeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateTransposeOptions( + _fbb); } -inline MeanOptionsT *MeanOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ExpOptionsT *ExpOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new ExpOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void ExpOptions::UnPackTo(ExpOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset ExpOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateExpOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateExpOptions(flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ExpOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateExpOptions( + _fbb); +} + +inline MeanOptionsT *MeanOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new MeanOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void MeanOptions::UnPackTo( - MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void MeanOptions::UnPackTo(MeanOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = keep_dims(); - _o->keep_dims = _e; - }; + { auto _e = keep_dims(); _o->keep_dims = _e; }; } -inline flatbuffers::Offset MeanOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset MeanOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateMeanOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateMeanOptions( - flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateMeanOptions(flatbuffers::FlatBufferBuilder &_fbb, const MeanOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const MeanOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const MeanOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _keep_dims = _o->keep_dims; - return tflite::CreateMeanOptions(_fbb, _keep_dims); + return tflite::CreateMeanOptions( + _fbb, + _keep_dims); } -inline SqueezeOptionsT *SqueezeOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SqueezeOptionsT *SqueezeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SqueezeOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void SqueezeOptions::UnPackTo( - SqueezeOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void SqueezeOptions::UnPackTo(SqueezeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = squeeze_dims(); - if (_e) { - _o->squeeze_dims.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->squeeze_dims[_i] = _e->Get(_i); - } - } - }; + { auto _e = squeeze_dims(); if (_e) { _o->squeeze_dims.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->squeeze_dims[_i] = _e->Get(_i); } } }; } -inline flatbuffers::Offset SqueezeOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SqueezeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSqueezeOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSqueezeOptions( - flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSqueezeOptions(flatbuffers::FlatBufferBuilder &_fbb, const SqueezeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SqueezeOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - auto _squeeze_dims = - _o->squeeze_dims.size() ? _fbb.CreateVector(_o->squeeze_dims) : 0; - return tflite::CreateSqueezeOptions(_fbb, _squeeze_dims); -} - -inline StridedSliceOptionsT *StridedSliceOptions::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SqueezeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _squeeze_dims = _o->squeeze_dims.size() ? _fbb.CreateVector(_o->squeeze_dims) : 0; + return tflite::CreateSqueezeOptions( + _fbb, + _squeeze_dims); +} + +inline SplitOptionsT *SplitOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new SplitOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void SplitOptions::UnPackTo(SplitOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = num_splits(); _o->num_splits = _e; }; +} + +inline flatbuffers::Offset SplitOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateSplitOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateSplitOptions(flatbuffers::FlatBufferBuilder &_fbb, const SplitOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SplitOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _num_splits = _o->num_splits; + return tflite::CreateSplitOptions( + _fbb, + _num_splits); +} + +inline StridedSliceOptionsT *StridedSliceOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new StridedSliceOptionsT(); UnPackTo(_o, _resolver); return _o; } -inline void StridedSliceOptions::UnPackTo( - StridedSliceOptionsT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void StridedSliceOptions::UnPackTo(StridedSliceOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = begin_mask(); - _o->begin_mask = _e; - }; - { - auto _e = end_mask(); - _o->end_mask = _e; - }; - { - auto _e = ellipsis_mask(); - _o->ellipsis_mask = _e; - }; - { - auto _e = new_axis_mask(); - _o->new_axis_mask = _e; - }; - { - auto _e = shrink_axis_mask(); - _o->shrink_axis_mask = _e; - }; + { auto _e = begin_mask(); _o->begin_mask = _e; }; + { auto _e = end_mask(); _o->end_mask = _e; }; + { auto _e = ellipsis_mask(); _o->ellipsis_mask = _e; }; + { auto _e = new_axis_mask(); _o->new_axis_mask = _e; }; + { auto _e = shrink_axis_mask(); _o->shrink_axis_mask = _e; }; } -inline flatbuffers::Offset StridedSliceOptions::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset StridedSliceOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateStridedSliceOptions(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateStridedSliceOptions( - flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateStridedSliceOptions(flatbuffers::FlatBufferBuilder &_fbb, const StridedSliceOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const StridedSliceOptionsT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const StridedSliceOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _begin_mask = _o->begin_mask; auto _end_mask = _o->end_mask; auto _ellipsis_mask = _o->ellipsis_mask; auto _new_axis_mask = _o->new_axis_mask; auto _shrink_axis_mask = _o->shrink_axis_mask; - return tflite::CreateStridedSliceOptions(_fbb, _begin_mask, _end_mask, - _ellipsis_mask, _new_axis_mask, - _shrink_axis_mask); + return tflite::CreateStridedSliceOptions( + _fbb, + _begin_mask, + _end_mask, + _ellipsis_mask, + _new_axis_mask, + _shrink_axis_mask); } -inline OperatorCodeT *OperatorCode::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline LogSoftmaxOptionsT *LogSoftmaxOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new LogSoftmaxOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void LogSoftmaxOptions::UnPackTo(LogSoftmaxOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset LogSoftmaxOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LogSoftmaxOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateLogSoftmaxOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateLogSoftmaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const LogSoftmaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LogSoftmaxOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateLogSoftmaxOptions( + _fbb); +} + +inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); return _o; } -inline void OperatorCode::UnPackTo( - OperatorCodeT *_o, - const flatbuffers::resolver_function_t *_resolver) const { +inline void OperatorCode::UnPackTo(OperatorCodeT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = builtin_code(); - _o->builtin_code = _e; - }; - { - auto _e = custom_code(); - if (_e) _o->custom_code = _e->str(); - }; + { auto _e = builtin_code(); _o->builtin_code = _e; }; + { auto _e = custom_code(); if (_e) _o->custom_code = _e->str(); }; } -inline flatbuffers::Offset OperatorCode::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset OperatorCode::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateOperatorCode(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateOperatorCode( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateOperatorCode(flatbuffers::FlatBufferBuilder &_fbb, const OperatorCodeT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const OperatorCodeT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OperatorCodeT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _builtin_code = _o->builtin_code; - auto _custom_code = - _o->custom_code.empty() ? 0 : _fbb.CreateString(_o->custom_code); - return tflite::CreateOperatorCode(_fbb, _builtin_code, _custom_code); + auto _custom_code = _o->custom_code.empty() ? 0 : _fbb.CreateString(_o->custom_code); + return tflite::CreateOperatorCode( + _fbb, + _builtin_code, + _custom_code); } -inline OperatorT *Operator::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline OperatorT *Operator::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorT(); UnPackTo(_o, _resolver); return _o; } -inline void Operator::UnPackTo( - OperatorT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void Operator::UnPackTo(OperatorT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = opcode_index(); - _o->opcode_index = _e; - }; - { - auto _e = inputs(); - if (_e) { - _o->inputs.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->inputs[_i] = _e->Get(_i); - } - } - }; - { - auto _e = outputs(); - if (_e) { - _o->outputs.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->outputs[_i] = _e->Get(_i); - } - } - }; - { - auto _e = builtin_options_type(); - _o->builtin_options.type = _e; - }; - { - auto _e = builtin_options(); - if (_e) - _o->builtin_options.value = - BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); - }; - { - auto _e = custom_options(); - if (_e) { - _o->custom_options.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->custom_options[_i] = _e->Get(_i); - } - } - }; - { - auto _e = custom_options_format(); - _o->custom_options_format = _e; - }; + { auto _e = opcode_index(); _o->opcode_index = _e; }; + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = _e->Get(_i); } } }; + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } }; + { auto _e = builtin_options_type(); _o->builtin_options.type = _e; }; + { auto _e = builtin_options(); if (_e) _o->builtin_options.value = BuiltinOptionsUnion::UnPack(_e, builtin_options_type(), _resolver); }; + { auto _e = custom_options(); if (_e) { _o->custom_options.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->custom_options[_i] = _e->Get(_i); } } }; + { auto _e = custom_options_format(); _o->custom_options_format = _e; }; } -inline flatbuffers::Offset Operator::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Operator::Pack(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateOperator(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateOperator( - flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateOperator(flatbuffers::FlatBufferBuilder &_fbb, const OperatorT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const OperatorT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const OperatorT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _opcode_index = _o->opcode_index; auto _inputs = _o->inputs.size() ? _fbb.CreateVector(_o->inputs) : 0; auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; auto _builtin_options_type = _o->builtin_options.type; auto _builtin_options = _o->builtin_options.Pack(_fbb); - auto _custom_options = - _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; + auto _custom_options = _o->custom_options.size() ? _fbb.CreateVector(_o->custom_options) : 0; auto _custom_options_format = _o->custom_options_format; - return tflite::CreateOperator(_fbb, _opcode_index, _inputs, _outputs, - _builtin_options_type, _builtin_options, - _custom_options, _custom_options_format); + return tflite::CreateOperator( + _fbb, + _opcode_index, + _inputs, + _outputs, + _builtin_options_type, + _builtin_options, + _custom_options, + _custom_options_format); } -inline SubGraphT *SubGraph::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline SubGraphT *SubGraph::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new SubGraphT(); UnPackTo(_o, _resolver); return _o; } -inline void SubGraph::UnPackTo( - SubGraphT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void SubGraph::UnPackTo(SubGraphT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = tensors(); - if (_e) { - _o->tensors.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->tensors[_i] = - std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); - } - } - }; - { - auto _e = inputs(); - if (_e) { - _o->inputs.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->inputs[_i] = _e->Get(_i); - } - } - }; - { - auto _e = outputs(); - if (_e) { - _o->outputs.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->outputs[_i] = _e->Get(_i); - } - } - }; - { - auto _e = operators(); - if (_e) { - _o->operators.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->operators[_i] = - std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); - } - } - }; - { - auto _e = name(); - if (_e) _o->name = _e->str(); - }; + { auto _e = tensors(); if (_e) { _o->tensors.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->tensors[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = inputs(); if (_e) { _o->inputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->inputs[_i] = _e->Get(_i); } } }; + { auto _e = outputs(); if (_e) { _o->outputs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->outputs[_i] = _e->Get(_i); } } }; + { auto _e = operators(); if (_e) { _o->operators.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->operators[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = name(); if (_e) _o->name = _e->str(); }; } -inline flatbuffers::Offset SubGraph::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset SubGraph::Pack(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateSubGraph(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateSubGraph( - flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateSubGraph(flatbuffers::FlatBufferBuilder &_fbb, const SubGraphT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const SubGraphT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; - auto _tensors = - _o->tensors.size() - ? _fbb.CreateVector>( - _o->tensors.size(), - [](size_t i, _VectorArgs *__va) { - return CreateTensor(*__va->__fbb, __va->__o->tensors[i].get(), - __va->__rehasher); - }, - &_va) - : 0; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const SubGraphT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _tensors = _o->tensors.size() ? _fbb.CreateVector> (_o->tensors.size(), [](size_t i, _VectorArgs *__va) { return CreateTensor(*__va->__fbb, __va->__o->tensors[i].get(), __va->__rehasher); }, &_va ) : 0; auto _inputs = _o->inputs.size() ? _fbb.CreateVector(_o->inputs) : 0; auto _outputs = _o->outputs.size() ? _fbb.CreateVector(_o->outputs) : 0; - auto _operators = _o->operators.size() - ? _fbb.CreateVector>( - _o->operators.size(), - [](size_t i, _VectorArgs *__va) { - return CreateOperator( - *__va->__fbb, __va->__o->operators[i].get(), - __va->__rehasher); - }, - &_va) - : 0; + auto _operators = _o->operators.size() ? _fbb.CreateVector> (_o->operators.size(), [](size_t i, _VectorArgs *__va) { return CreateOperator(*__va->__fbb, __va->__o->operators[i].get(), __va->__rehasher); }, &_va ) : 0; auto _name = _o->name.empty() ? 0 : _fbb.CreateString(_o->name); - return tflite::CreateSubGraph(_fbb, _tensors, _inputs, _outputs, _operators, - _name); + return tflite::CreateSubGraph( + _fbb, + _tensors, + _inputs, + _outputs, + _operators, + _name); } -inline BufferT *Buffer::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline BufferT *Buffer::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new BufferT(); UnPackTo(_o, _resolver); return _o; } -inline void Buffer::UnPackTo( - BufferT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void Buffer::UnPackTo(BufferT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = data(); - if (_e) { - _o->data.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->data[_i] = _e->Get(_i); - } - } - }; + { auto _e = data(); if (_e) { _o->data.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->data[_i] = _e->Get(_i); } } }; } -inline flatbuffers::Offset Buffer::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Buffer::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BufferT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateBuffer(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateBuffer( - flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateBuffer(flatbuffers::FlatBufferBuilder &_fbb, const BufferT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const BufferT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BufferT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _data = _o->data.size() ? _fbb.CreateVector(_o->data) : 0; - return tflite::CreateBuffer(_fbb, _data); + return tflite::CreateBuffer( + _fbb, + _data); } -inline ModelT *Model::UnPack( - const flatbuffers::resolver_function_t *_resolver) const { +inline ModelT *Model::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ModelT(); UnPackTo(_o, _resolver); return _o; } -inline void Model::UnPackTo( - ModelT *_o, const flatbuffers::resolver_function_t *_resolver) const { +inline void Model::UnPackTo(ModelT *_o, const flatbuffers::resolver_function_t *_resolver) const { (void)_o; (void)_resolver; - { - auto _e = version(); - _o->version = _e; - }; - { - auto _e = operator_codes(); - if (_e) { - _o->operator_codes.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->operator_codes[_i] = - std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); - } - } - }; - { - auto _e = subgraphs(); - if (_e) { - _o->subgraphs.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->subgraphs[_i] = - std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); - } - } - }; - { - auto _e = description(); - if (_e) _o->description = _e->str(); - }; - { - auto _e = buffers(); - if (_e) { - _o->buffers.resize(_e->size()); - for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { - _o->buffers[_i] = - std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); - } - } - }; + { auto _e = version(); _o->version = _e; }; + { auto _e = operator_codes(); if (_e) { _o->operator_codes.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->operator_codes[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = subgraphs(); if (_e) { _o->subgraphs.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->subgraphs[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; + { auto _e = description(); if (_e) _o->description = _e->str(); }; + { auto _e = buffers(); if (_e) { _o->buffers.resize(_e->size()); for (flatbuffers::uoffset_t _i = 0; _i < _e->size(); _i++) { _o->buffers[_i] = std::unique_ptr(_e->Get(_i)->UnPack(_resolver)); } } }; } -inline flatbuffers::Offset Model::Pack( - flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset Model::Pack(flatbuffers::FlatBufferBuilder &_fbb, const ModelT* _o, const flatbuffers::rehasher_function_t *_rehasher) { return CreateModel(_fbb, _o, _rehasher); } -inline flatbuffers::Offset CreateModel( - flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, - const flatbuffers::rehasher_function_t *_rehasher) { +inline flatbuffers::Offset CreateModel(flatbuffers::FlatBufferBuilder &_fbb, const ModelT *_o, const flatbuffers::rehasher_function_t *_rehasher) { (void)_rehasher; (void)_o; - struct _VectorArgs { - flatbuffers::FlatBufferBuilder *__fbb; - const ModelT *__o; - const flatbuffers::rehasher_function_t *__rehasher; - } _va = {&_fbb, _o, _rehasher}; - (void)_va; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const ModelT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _version = _o->version; - auto _operator_codes = - _o->operator_codes.size() - ? _fbb.CreateVector>( - _o->operator_codes.size(), - [](size_t i, _VectorArgs *__va) { - return CreateOperatorCode(*__va->__fbb, - __va->__o->operator_codes[i].get(), - __va->__rehasher); - }, - &_va) - : 0; - auto _subgraphs = _o->subgraphs.size() - ? _fbb.CreateVector>( - _o->subgraphs.size(), - [](size_t i, _VectorArgs *__va) { - return CreateSubGraph( - *__va->__fbb, __va->__o->subgraphs[i].get(), - __va->__rehasher); - }, - &_va) - : 0; - auto _description = - _o->description.empty() ? 0 : _fbb.CreateString(_o->description); - auto _buffers = - _o->buffers.size() - ? _fbb.CreateVector>( - _o->buffers.size(), - [](size_t i, _VectorArgs *__va) { - return CreateBuffer(*__va->__fbb, __va->__o->buffers[i].get(), - __va->__rehasher); - }, - &_va) - : 0; - return tflite::CreateModel(_fbb, _version, _operator_codes, _subgraphs, - _description, _buffers); -} - -inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, - const void *obj, BuiltinOptions type) { + auto _operator_codes = _o->operator_codes.size() ? _fbb.CreateVector> (_o->operator_codes.size(), [](size_t i, _VectorArgs *__va) { return CreateOperatorCode(*__va->__fbb, __va->__o->operator_codes[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _subgraphs = _o->subgraphs.size() ? _fbb.CreateVector> (_o->subgraphs.size(), [](size_t i, _VectorArgs *__va) { return CreateSubGraph(*__va->__fbb, __va->__o->subgraphs[i].get(), __va->__rehasher); }, &_va ) : 0; + auto _description = _o->description.empty() ? 0 : _fbb.CreateString(_o->description); + auto _buffers = _o->buffers.size() ? _fbb.CreateVector> (_o->buffers.size(), [](size_t i, _VectorArgs *__va) { return CreateBuffer(*__va->__fbb, __va->__o->buffers[i].get(), __va->__rehasher); }, &_va ) : 0; + return tflite::CreateModel( + _fbb, + _version, + _operator_codes, + _subgraphs, + _description, + _buffers); +} + +inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type) { switch (type) { case BuiltinOptions_NONE: { return true; @@ -6515,8 +5741,7 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, return verifier.VerifyTable(ptr); } case BuiltinOptions_LocalResponseNormalizationOptions: { - auto ptr = - reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } case BuiltinOptions_LSTMOptions: { @@ -6595,28 +5820,39 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } - default: - return false; + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_LogSoftmaxOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return false; } } -inline bool VerifyBuiltinOptionsVector( - flatbuffers::Verifier &verifier, - const flatbuffers::Vector> *values, - const flatbuffers::Vector *types) { +inline bool VerifyBuiltinOptionsVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; if (values->size() != types->size()) return false; for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { - if (!VerifyBuiltinOptions(verifier, values->Get(i), - types->GetEnum(i))) { + if (!VerifyBuiltinOptions( + verifier, values->Get(i), types->GetEnum(i))) { return false; } } return true; } -inline void *BuiltinOptionsUnion::UnPack( - const void *obj, BuiltinOptions type, - const flatbuffers::resolver_function_t *resolver) { +inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, const flatbuffers::resolver_function_t *resolver) { switch (type) { case BuiltinOptions_Conv2DOptions: { auto ptr = reinterpret_cast(obj); @@ -6667,8 +5903,7 @@ inline void *BuiltinOptionsUnion::UnPack( return ptr->UnPack(resolver); } case BuiltinOptions_LocalResponseNormalizationOptions: { - auto ptr = - reinterpret_cast(obj); + auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } case BuiltinOptions_LSTMOptions: { @@ -6747,14 +5982,27 @@ inline void *BuiltinOptionsUnion::UnPack( auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } - default: - return nullptr; + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_LogSoftmaxOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } + default: return nullptr; } } -inline flatbuffers::Offset BuiltinOptionsUnion::Pack( - flatbuffers::FlatBufferBuilder &_fbb, - const flatbuffers::rehasher_function_t *_rehasher) const { +inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBufferBuilder &_fbb, const flatbuffers::rehasher_function_t *_rehasher) const { switch (type) { case BuiltinOptions_Conv2DOptions: { auto ptr = reinterpret_cast(value); @@ -6805,10 +6053,8 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack( return CreateL2NormOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LocalResponseNormalizationOptions: { - auto ptr = - reinterpret_cast(value); - return CreateLocalResponseNormalizationOptions(_fbb, ptr, _rehasher) - .Union(); + auto ptr = reinterpret_cast(value); + return CreateLocalResponseNormalizationOptions(_fbb, ptr, _rehasher).Union(); } case BuiltinOptions_LSTMOptions: { auto ptr = reinterpret_cast(value); @@ -6886,32 +6132,42 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack( auto ptr = reinterpret_cast(value); return CreateStridedSliceOptions(_fbb, ptr, _rehasher).Union(); } - default: - return 0; + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(value); + return CreateExpOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(value); + return CreateTopKV2Options(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(value); + return CreateSplitOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_LogSoftmaxOptions: { + auto ptr = reinterpret_cast(value); + return CreateLogSoftmaxOptions(_fbb, ptr, _rehasher).Union(); + } + default: return 0; } } -inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) - FLATBUFFERS_NOEXCEPT : type(u.type), - value(nullptr) { +inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FLATBUFFERS_NOEXCEPT : type(u.type), value(nullptr) { switch (type) { case BuiltinOptions_Conv2DOptions: { value = new Conv2DOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_DepthwiseConv2DOptions: { - value = new DepthwiseConv2DOptionsT( - *reinterpret_cast(u.value)); + value = new DepthwiseConv2DOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ConcatEmbeddingsOptions: { - value = new ConcatEmbeddingsOptionsT( - *reinterpret_cast(u.value)); + value = new ConcatEmbeddingsOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LSHProjectionOptions: { - value = new LSHProjectionOptionsT( - *reinterpret_cast(u.value)); + value = new LSHProjectionOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_Pool2DOptions: { @@ -6927,18 +6183,15 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_FullyConnectedOptions: { - value = new FullyConnectedOptionsT( - *reinterpret_cast(u.value)); + value = new FullyConnectedOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SoftmaxOptions: { - value = - new SoftmaxOptionsT(*reinterpret_cast(u.value)); + value = new SoftmaxOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_ConcatenationOptions: { - value = new ConcatenationOptionsT( - *reinterpret_cast(u.value)); + value = new ConcatenationOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_AddOptions: { @@ -6950,8 +6203,7 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_LocalResponseNormalizationOptions: { - value = new LocalResponseNormalizationOptionsT( - *reinterpret_cast(u.value)); + value = new LocalResponseNormalizationOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_LSTMOptions: { @@ -6959,8 +6211,7 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_ResizeBilinearOptions: { - value = new ResizeBilinearOptionsT( - *reinterpret_cast(u.value)); + value = new ResizeBilinearOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_CallOptions: { @@ -6968,23 +6219,19 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_ReshapeOptions: { - value = - new ReshapeOptionsT(*reinterpret_cast(u.value)); + value = new ReshapeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SkipGramOptions: { - value = - new SkipGramOptionsT(*reinterpret_cast(u.value)); + value = new SkipGramOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SpaceToDepthOptions: { - value = new SpaceToDepthOptionsT( - *reinterpret_cast(u.value)); + value = new SpaceToDepthOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_EmbeddingLookupSparseOptions: { - value = new EmbeddingLookupSparseOptionsT( - *reinterpret_cast(u.value)); + value = new EmbeddingLookupSparseOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_MulOptions: { @@ -7000,18 +6247,15 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_BatchToSpaceNDOptions: { - value = new BatchToSpaceNDOptionsT( - *reinterpret_cast(u.value)); + value = new BatchToSpaceNDOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SpaceToBatchNDOptions: { - value = new SpaceToBatchNDOptionsT( - *reinterpret_cast(u.value)); + value = new SpaceToBatchNDOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_TransposeOptions: { - value = new TransposeOptionsT( - *reinterpret_cast(u.value)); + value = new TransposeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_MeanOptions: { @@ -7027,18 +6271,31 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) break; } case BuiltinOptions_SqueezeOptions: { - value = - new SqueezeOptionsT(*reinterpret_cast(u.value)); + value = new SqueezeOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_SequenceRNNOptions: { - value = new SequenceRNNOptionsT( - *reinterpret_cast(u.value)); + value = new SequenceRNNOptionsT(*reinterpret_cast(u.value)); break; } case BuiltinOptions_StridedSliceOptions: { - value = new StridedSliceOptionsT( - *reinterpret_cast(u.value)); + value = new StridedSliceOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_ExpOptions: { + value = new ExpOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_TopKV2Options: { + value = new TopKV2OptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_SplitOptions: { + value = new SplitOptionsT(*reinterpret_cast(u.value)); + break; + } + case BuiltinOptions_LogSoftmaxOptions: { + value = new LogSoftmaxOptionsT(*reinterpret_cast(u.value)); break; } default: @@ -7208,8 +6465,27 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } - default: + case BuiltinOptions_ExpOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; break; + } + case BuiltinOptions_TopKV2Options: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_SplitOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + case BuiltinOptions_LogSoftmaxOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } + default: break; } value = nullptr; type = BuiltinOptions_NONE; @@ -7219,25 +6495,33 @@ inline const tflite::Model *GetModel(const void *buf) { return flatbuffers::GetRoot(buf); } -inline const char *ModelIdentifier() { return "TFL3"; } +inline const char *ModelIdentifier() { + return "TFL3"; +} inline bool ModelBufferHasIdentifier(const void *buf) { - return flatbuffers::BufferHasIdentifier(buf, ModelIdentifier()); + return flatbuffers::BufferHasIdentifier( + buf, ModelIdentifier()); } -inline bool VerifyModelBuffer(flatbuffers::Verifier &verifier) { +inline bool VerifyModelBuffer( + flatbuffers::Verifier &verifier) { return verifier.VerifyBuffer(ModelIdentifier()); } -inline const char *ModelExtension() { return "tflite"; } +inline const char *ModelExtension() { + return "tflite"; +} -inline void FinishModelBuffer(flatbuffers::FlatBufferBuilder &fbb, - flatbuffers::Offset root) { +inline void FinishModelBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { fbb.Finish(root, ModelIdentifier()); } inline std::unique_ptr UnPackModel( - const void *buf, const flatbuffers::resolver_function_t *res = nullptr) { + const void *buf, + const flatbuffers::resolver_function_t *res = nullptr) { return std::unique_ptr(GetModel(buf)->UnPack(res)); } diff --git a/tensorflow/contrib/lite/testdata/multi_add.pb b/tensorflow/contrib/lite/testdata/multi_add.pb new file mode 100644 index 0000000000000000000000000000000000000000..e95a20841fb2b320bd77994d9dda157d79311dd6 --- /dev/null +++ b/tensorflow/contrib/lite/testdata/multi_add.pb @@ -0,0 +1,26 @@ + +I +a Placeholder" /device:CPU:0* +shape:* +dtype0 +I +b Placeholder" /device:CPU:0* +dtype0* +shape: +I +c Placeholder" /device:CPU:0* +dtype0* +shape: +I +d Placeholder" /device:CPU:0* +dtype0* +shape: +& +iAddbc" /device:CPU:0* +T0 +& +xAddai" /device:CPU:0* +T0 +& +yAdddi" /device:CPU:0* +T0" \ No newline at end of file diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index b949045128fc15b6abe8f6c59d63dfd2b47c3c30..83b9e2142798c685cbc8e1fd4d1db5c40b70389f 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -25,6 +25,7 @@ gen_zipped_test_files( "conv.zip", "depthwiseconv.zip", "div.zip", + "exp.zip", "fully_connected.zip", "fused_batch_norm.zip", "gather.zip", @@ -32,6 +33,8 @@ gen_zipped_test_files( "l2_pool.zip", "l2norm.zip", "local_response_norm.zip", + "log_softmax.zip", + "lstm.zip", "max_pool.zip", "mean.zip", "mul.zip", @@ -45,9 +48,11 @@ gen_zipped_test_files( "softmax.zip", "space_to_batch_nd.zip", "space_to_depth.zip", + "split.zip", "squeeze.zip", "strided_slice.zip", "sub.zip", + "topk.zip", "transpose.zip", ], ) @@ -121,6 +126,21 @@ cc_test( ], ) +cc_library( + name = "join", + hdrs = ["join.h"], +) + +cc_test( + name = "join_test", + size = "small", + srcs = ["join_test.cc"], + deps = [ + ":join", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "tflite_driver", srcs = ["tflite_driver.cc"], @@ -195,6 +215,118 @@ cc_binary( ], ) +cc_library( + name = "tf_driver", + srcs = ["tf_driver.cc"], + hdrs = ["tf_driver.h"], + deps = [ + ":join", + ":split", + ":test_runner", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + ], +) + +cc_test( + name = "tf_driver_test", + size = "small", + srcs = ["tf_driver_test.cc"], + data = ["//tensorflow/contrib/lite:testdata/multi_add.pb"], + deps = [ + ":tf_driver", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "generate_testspec", + testonly = 1, + srcs = ["generate_testspec.cc"], + hdrs = ["generate_testspec.h"], + deps = [ + ":join", + ":split", + ":tf_driver", + "//tensorflow/core:framework", + ], +) + +cc_test( + name = "generate_testspec_test", + size = "small", + srcs = ["generate_testspec_test.cc"], + deps = [ + ":generate_testspec", + "@com_google_googletest//:gtest_main", + ], +) + +cc_library( + name = "tflite_diff_util", + testonly = 1, + srcs = ["tflite_diff_util.cc"], + hdrs = ["tflite_diff_util.h"], + deps = [ + ":generate_testspec", + ":parse_testdata_lib", + ":split", + ":tflite_driver", + ":util", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_library( + name = "tflite_diff_flags", + testonly = 1, + hdrs = ["tflite_diff_flags.h"], + deps = [ + ":split", + ":tflite_diff_util", + ] + select({ + "//conditions:default": [ + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], + "//tensorflow:android": [ + "//tensorflow/core:android_tensorflow_lib", + ], + }), +) + +tf_cc_test( + name = "tflite_diff_example_test", + size = "medium", + srcs = ["tflite_diff_example_test.cc"], + args = [ + "--tensorflow_model=third_party/tensorflow/contrib/lite/testdata/multi_add.pb", + "--tflite_model=third_party/tensorflow/contrib/lite/testdata/multi_add.bin", + "--input_layer=a,b,c,d", + "--input_layer_type=float,float,float,float", + "--input_layer_shape=1,3,4,3:1,3,4,3:1,3,4,3:1,3,4,3", + "--output_layer=x,y", + ], + data = [ + "//tensorflow/contrib/lite:testdata/multi_add.bin", + "//tensorflow/contrib/lite:testdata/multi_add.pb", + ], + tags = [ + "no_cuda_on_cpu_tap", + "no_oss", + ], + deps = [ + ":tflite_diff_flags", + ":tflite_diff_util", + ], +) + tf_cc_test( name = "generated_examples_zip_test", size = "large", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 147400ec37c606308244cc862bbee5b88ba553ec..5488b71fcf644070710acc4b2b2886e9a96facb6 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -36,6 +36,7 @@ import traceback import zipfile import numpy as np from six import StringIO +from six.moves import xrange # TODO(aselle): Disable GPU for now os.environ["CUDA_VISIBLE_DEVICES"] = "-1" @@ -46,6 +47,7 @@ from google.protobuf import text_format # TODO(aselle): switch to TensorFlow's resource_loader from tensorflow.contrib.lite.testing import generate_examples_report as report_lib from tensorflow.python.framework import graph_util as tf_graph_util +from tensorflow.python.ops import rnn parser = argparse.ArgumentParser(description="Script to generate TFLite tests.") parser.add_argument("output_path", @@ -99,16 +101,32 @@ KNOWN_BUGS = { r"batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\]": "70594634", # BatchToSpaceND only supports 4D tensors. r"batch_to_space_nd.*input_shape=\[8,2,2,2,1,1\]": "70594733", - # Div will use floordiv - r"div.*int32": "72051395" + # Div will use floordiv. + r"div.*int32": "72051395", + # TOCO require matching dimensions in strided_slice. + r"strided_slice.*begin=\[0\].*end=\[1\].*": "73170889", + # No support for SplitV + r"split.*num_or_size_splits=\[2,2\]": "73377559", } +class ExtraTocoOptions(object): + """Additonal toco options besides input, output, shape.""" + + def __init__(self): + # Whether to ignore control dependency nodes. + self.drop_control_dependency = False + # Allow custom ops in the toco conversion. + self.allow_custom_ops = False + # Rnn states that are used to support rnn / lstm cells. + self.rnn_states = None + + def toco_options(data_types, input_arrays, output_arrays, shapes, - drop_control_dependency): + extra_toco_options=ExtraTocoOptions()): """Create TOCO options to process a model. Args: @@ -116,8 +134,7 @@ def toco_options(data_types, input_arrays: names of the input tensors output_arrays: name of the output tensors shapes: shapes of the input tensors - drop_control_dependency: whether to ignore control dependency nodes. - + extra_toco_options: additional toco options Returns: the options in a string. """ @@ -133,37 +150,15 @@ def toco_options(data_types, " --input_arrays=%s" % ",".join(input_arrays) + " --input_shapes=%s" % shape_str + " --output_arrays=%s" % ",".join(output_arrays)) - if drop_control_dependency: + if extra_toco_options.drop_control_dependency: s += " --drop_control_dependency" + if extra_toco_options.allow_custom_ops: + s += " --allow_custom_ops" + if extra_toco_options.rnn_states: + s += (" --rnn_states='" + extra_toco_options.rnn_states + "'") return s -def write_toco_options(filename, - data_types, - input_arrays, - output_arrays, - shapes, - drop_control_dependency=False): - """Create TOCO options to process a model. - - Args: - filename: Filename to write the options to. - data_types: input and inference types used by TOCO. - input_arrays: names of the input tensors - output_arrays: names of the output tensors - shapes: shapes of the input tensors - drop_control_dependency: whether to ignore control dependency nodes. - """ - with open(filename, "w") as fp: - fp.write( - toco_options( - data_types=data_types, - input_arrays=input_arrays, - output_arrays=output_arrays, - shapes=shapes, - drop_control_dependency=drop_control_dependency)) - - def write_examples(fp, examples): """Given a list `examples`, write a text format representation. @@ -241,7 +236,7 @@ def create_tensor_data(dtype, shape, min_value=-100, max_value=100): if dtype in (tf.float32, tf.float16): value = (max_value-min_value)*np.random.random_sample(shape)+min_value elif dtype in (tf.int32, tf.uint8, tf.int64): - value = np.random.random_integers(min_value, max_value, shape) + value = np.random.randint(min_value, max_value+1, shape) return value.astype(dtype) @@ -281,12 +276,14 @@ def make_control_dep_tests(zip_path): return [input_values], sess.run( outputs, feed_dict=dict(zip(inputs, [input_values]))) + extra_toco_options = ExtraTocoOptions() + extra_toco_options.drop_control_dependency = True make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs, - drop_control_dependency=True) + extra_toco_options) def toco_convert(graph_def_str, input_tensors, output_tensors, - drop_control_dependency=False): + extra_toco_options): """Convert a model's graph def into a tflite model. NOTE: this currently shells out to the toco binary, but we would like @@ -294,9 +291,9 @@ def toco_convert(graph_def_str, input_tensors, output_tensors, Args: graph_def_str: Graph def proto in serialized string format. - input_tensors: List of input tensor tuples `(name, shape, type)` - output_tensors: List of output tensors (names) - drop_control_dependency: whether to ignore control dependency nodes. + input_tensors: List of input tensor tuples `(name, shape, type)`. + output_tensors: List of output tensors (names). + extra_toco_options: Additional toco options. Returns: output tflite model, log_txt from conversion @@ -308,7 +305,7 @@ def toco_convert(graph_def_str, input_tensors, output_tensors, input_arrays=[x[0] for x in input_tensors], shapes=[x[1] for x in input_tensors], output_arrays=output_tensors, - drop_control_dependency=drop_control_dependency) + extra_toco_options=extra_toco_options) with tempfile.NamedTemporaryFile() as graphdef_file, \ tempfile.NamedTemporaryFile() as output_file, \ @@ -327,11 +324,18 @@ def toco_convert(graph_def_str, input_tensors, output_tensors, return (None if exit_code != 0 else output_file.read()), log +def normalize_output_name(output_name): + """Remove :0 suffix from tensor names.""" + return output_name.split(":")[0] if output_name.endswith( + ":0") else output_name + + def make_zip_of_tests(zip_path, test_parameters, make_graph, make_test_inputs, - drop_control_dependency=False): + extra_toco_options=ExtraTocoOptions(), + use_frozen_graph=False): """Helper to make a zip file of a bunch of TensorFlow models. This does a cartestian product of the dictionary of test_parameters and @@ -349,7 +353,9 @@ def make_zip_of_tests(zip_path, `[input1, input2, ...], [output1, output2, ...]` make_test_inputs: function taking `curr_params`, `session`, `input_tensors`, `output_tensors` and returns tuple `(input_values, output_values)`. - drop_control_dependency: whether to ignore control dependency nodes. + extra_toco_options: Additional toco options. + use_frozen_graph: Whether or not freeze graph before toco converter. + Raises: RuntimeError: if there are toco errors that can't be ignored. """ @@ -409,21 +415,25 @@ def make_zip_of_tests(zip_path, return None, report report["toco"] = report_lib.FAILED report["tf"] = report_lib.SUCCESS - # Convert graph to toco + input_tensors = [(input_tensor.name.split(":")[0], + input_tensor.get_shape(), input_tensor.dtype) + for input_tensor in inputs] + output_tensors = [normalize_output_name(out.name) for out in outputs] + graph_def = freeze_graph( + sess, + tf.global_variables() + inputs + + outputs) if use_frozen_graph else sess.graph_def tflite_model_binary, toco_log = toco_convert( - sess.graph_def.SerializeToString(), - [(input_tensor.name.split(":")[0], input_tensor.get_shape(), - input_tensor.dtype) for input_tensor in inputs], - [out.name.split(":")[0] - for out in outputs], drop_control_dependency) + graph_def.SerializeToString(), input_tensors, output_tensors, + extra_toco_options) report["toco"] = (report_lib.SUCCESS if tflite_model_binary is not None else report_lib.FAILED) report["toco_log"] = toco_log if FLAGS.save_graphdefs: archive.writestr(label + ".pb", - text_format.MessageToString(sess.graph_def), + text_format.MessageToString(graph_def), zipfile.ZIP_DEFLATED) if tflite_model_binary: @@ -745,6 +755,65 @@ def make_mean_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_exp_tests(zip_path): + """Make a set of tests to do exp.""" + + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]], + }] + + def build_graph(parameters): + """Build the exp op testing graph.""" + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + + out = tf.exp(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data(parameters["input_dtype"], parameters["input_shape"], + min_value=-100, max_value=9) + ] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + +def make_log_softmax_tests(zip_path): + """Make a set of tests to do log_softmax.""" + + test_parameters = [{ + "input_dtype": [tf.float32], + "input_shape": [[1, 100], [4, 2], [5, 224]], + }] + + def build_graph(parameters): + """Build the log_softmax op testing graph.""" + input_tensor = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + + out = tf.nn.log_softmax(input_tensor) + return [input_tensor], [out] + + def build_inputs(parameters, sess, inputs, outputs): + values = [ + create_tensor_data( + parameters["input_dtype"], + parameters["input_shape"], + min_value=-100, + max_value=9) + ] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_binary_op_tests_func(binary_operator): """Return a function that does a test on a binary operator.""" return lambda zip_path: make_binary_op_tests(zip_path, binary_operator) @@ -995,8 +1064,31 @@ def make_depthwiseconv_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_split_tests(zip_path): + """Make a set of tests to do tf.split.""" + + test_parameters = [{ + "input_shape": [[1, 3, 4, 6], [2, 4, 1], [6, 4], [8]], + "num_or_size_splits": [1, 2, 3, 4, 5, [2, 2]], + "axis": [0, 1, 2, 3, -4, -3, -2, -1], + }] + + def build_graph(parameters): + input_tensor = tf.placeholder( + dtype=tf.float32, name="input", shape=parameters["input_shape"]) + out = tf.split( + input_tensor, parameters["num_or_size_splits"], parameters["axis"]) + return [input_tensor], out + + def build_inputs(parameters, sess, inputs, outputs): + values = [create_tensor_data(np.float32, parameters["input_shape"])] + return values, sess.run(outputs, feed_dict=dict(zip(inputs, values))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_concatenation_tests(zip_path): - """Make a set of tests to do concatenatinon.""" + """Make a set of tests to do concatenation.""" test_parameters = [{ "base_shape": [[1, 3, 4, 3], [3, 4]], @@ -1568,6 +1660,19 @@ def make_strided_slice_tests(zip_path): "shrink_axis_mask": [None, 1, 8, 11, 15, -1], "constant_indices": [False, True], }, + # + { + "dtype": [tf.float32], + "index_type": [tf.int32], + "input_shape": [[12, 2, 2, 5]], + "begin": [[0]], + "end": [[1]], + "strides": [[1]], + "begin_mask": [0], + "end_mask": [0], + "shrink_axis_mask": [1], + "constant_indices": [True], + }, # 2-D { "dtype": [tf.float32, tf.int32, tf.int64], @@ -1583,7 +1688,7 @@ def make_strided_slice_tests(zip_path): }, # Negative strides { - "dtype": [tf.float32, tf.int32, tf.int64], + "dtype": [tf.float32], "index_type": [tf.int32], "input_shape": [[2, 3]], "begin": [[0, -1]], @@ -1656,6 +1761,84 @@ def make_strided_slice_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_lstm_tests(zip_path): + """Make a set of tests to do basic Lstm cell.""" + + test_parameters = [ + { + "dtype": [tf.float32], + "num_batchs": [1], + "time_step_size": [1], + "input_vec_size": [3], + "num_cells": [4], + }, + ] + + def build_graph(parameters): + """Build a simple graph with BasicLSTMCell.""" + + num_batchs = parameters["num_batchs"] + time_step_size = parameters["time_step_size"] + input_vec_size = parameters["input_vec_size"] + num_cells = parameters["num_cells"] + inputs_after_split = [] + for i in xrange(time_step_size): + one_timestamp_input = tf.placeholder( + dtype=parameters["dtype"], + name="split_{}".format(i), + shape=[num_batchs, input_vec_size]) + inputs_after_split.append(one_timestamp_input) + # Currently lstm identifier has a few limitations: only supports + # forget_bias == 0, inner state activiation == tanh. + # TODO(zhixianyan): Add another test with forget_bias == 1. + # TODO(zhixianyan): Add another test with relu as activation. + lstm_cell = tf.contrib.rnn.BasicLSTMCell( + num_cells, forget_bias=0.0, state_is_tuple=True) + cell_outputs, _ = rnn.static_rnn( + lstm_cell, inputs_after_split, dtype=tf.float32) + out = cell_outputs[-1] + return inputs_after_split, [out] + + def build_inputs(parameters, sess, inputs, outputs): + """Feed inputs, assign vairables, and freeze graph.""" + + with tf.variable_scope("", reuse=True): + kernel = tf.get_variable("rnn/basic_lstm_cell/kernel") + bias = tf.get_variable("rnn/basic_lstm_cell/bias") + kernel_values = create_tensor_data( + parameters["dtype"], [kernel.shape[0], kernel.shape[1]], -1, 1) + bias_values = create_tensor_data(parameters["dtype"], [bias.shape[0]], 0, + 1) + sess.run(tf.group(kernel.assign(kernel_values), bias.assign(bias_values))) + + num_batchs = parameters["num_batchs"] + time_step_size = parameters["time_step_size"] + input_vec_size = parameters["input_vec_size"] + input_values = [] + for _ in xrange(time_step_size): + tensor_data = create_tensor_data(parameters["dtype"], + [num_batchs, input_vec_size], 0, 1) + input_values.append(tensor_data) + out = sess.run(outputs, feed_dict=dict(zip(inputs, input_values))) + return input_values, out + + # TODO(zhixianyan): Automatically generate rnn_states for lstm cell. + extra_toco_options = ExtraTocoOptions() + extra_toco_options.rnn_states = ( + "{state_array:rnn/BasicLSTMCellZeroState/zeros," + "back_edge_source_array:rnn/basic_lstm_cell/Add_1,size:4}," + "{state_array:rnn/BasicLSTMCellZeroState/zeros_1," + "back_edge_source_array:rnn/basic_lstm_cell/Mul_2,size:4}") + + make_zip_of_tests( + zip_path, + test_parameters, + build_graph, + build_inputs, + extra_toco_options, + use_frozen_graph=True) + + def make_l2_pool(input_tensor, ksize, strides, padding, data_format): """Given an input perform a sequence of TensorFlow ops to produce l2pool.""" return tf.sqrt(tf.nn.avg_pool( @@ -1663,6 +1846,32 @@ def make_l2_pool(input_tensor, ksize, strides, padding, data_format): padding=padding, data_format=data_format)) +def make_topk_tests(zip_path): + """Make a set of tests to do gather.""" + + test_parameters = [{ + "input_dtype": [tf.float32, tf.int32], + "input_shape": [[10], [5, 20]], + }] + + def build_graph(parameters): + """Build the gather op testing graph.""" + input_value = tf.placeholder( + dtype=parameters["input_dtype"], + name="input", + shape=parameters["input_shape"]) + k = tf.constant(3, name="k") + out = tf.nn.top_k(input_value, k) + return [input_value], [out[1]] + + def build_inputs(parameters, sess, inputs, outputs): + input_value = create_tensor_data(parameters["input_dtype"], + parameters["input_shape"]) + return [input_value], sess.run( + outputs, feed_dict=dict(zip(inputs, [input_value]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + # Toco binary path provided by the generate rule. bin_path = None @@ -1711,10 +1920,15 @@ def main(unused_args): "sigmoid.zip": make_sigmoid_tests, "softmax.zip": make_softmax_tests, "space_to_depth.zip": make_space_to_depth_tests, + "topk.zip": make_topk_tests, + "split.zip": make_split_tests, "transpose.zip": make_transpose_tests, "mean.zip": make_mean_tests, "squeeze.zip": make_squeeze_tests, "strided_slice.zip": make_strided_slice_tests, + "exp.zip": make_exp_tests, + "log_softmax.zip": make_log_softmax_tests, + "lstm.zip": make_lstm_tests, } out = FLAGS.zip_to_output bin_path = FLAGS.toco diff --git a/tensorflow/contrib/lite/testing/generate_testspec.cc b/tensorflow/contrib/lite/testing/generate_testspec.cc new file mode 100644 index 0000000000000000000000000000000000000000..eb3deafb6986e877f0a553a8b6f712102af4caca --- /dev/null +++ b/tensorflow/contrib/lite/testing/generate_testspec.cc @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/lite/testing/generate_testspec.h" +#include "tensorflow/contrib/lite/testing/join.h" +#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/contrib/lite/testing/tf_driver.h" +#include "tensorflow/core/framework/types.h" + +namespace tflite { +namespace testing { + +void GenerateTestSpecFromTensorflowModel( + std::iostream& stream, const string& tensorflow_model_path, + const string& tflite_model_path, const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer) { + CHECK_EQ(input_layer.size(), input_layer_type.size()); + CHECK_EQ(input_layer.size(), input_layer_shape.size()); + + // Initialize random functions. + static unsigned int seed = 0; + std::function float_rand = [](int idx) { + return static_cast(rand_r(&seed)) / RAND_MAX - 0.5f; + }; + + // Generate inputs. + std::vector input_values; + input_values.resize(input_layer.size()); + for (int i = 0; i < input_layer.size(); i++) { + tensorflow::DataType type; + CHECK(DataTypeFromString(input_layer_type[i], &type)); + auto shape = Split(input_layer_shape[i], ","); + + switch (type) { + case tensorflow::DT_FLOAT: { + const auto& data = GenerateRandomTensor(shape, float_rand); + input_values[i] = Join(data.data(), data.size(), ","); + break; + } + default: + + fprintf(stderr, "Unsupported type %d when generating testspec\n", type); + return; + } + } + + // Invoke tensorflow model. + TfDriver runner(input_layer, input_layer_type, input_layer_shape, + output_layer); + runner.LoadModel(tensorflow_model_path); + for (int i = 0; i < input_values.size(); i++) { + runner.SetInput(i, input_values[i]); + } + runner.Invoke(); + + // Write test spec. + stream << "load_model: " << tflite_model_path << "\n"; + stream << "reshape {\n"; + for (const auto& shape : input_layer_shape) { + stream << " input: \"" << shape << "\"\n"; + } + stream << "}\n"; + stream << "invoke {\n"; + for (const auto& value : input_values) { + stream << " input: \"" << value << "\"\n"; + } + for (int i = 0; i < output_layer.size(); i++) { + stream << " output: \"" << runner.ReadOutput(i) << "\"\n"; + } + stream << "}\n"; +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/generate_testspec.h b/tensorflow/contrib/lite/testing/generate_testspec.h new file mode 100644 index 0000000000000000000000000000000000000000..3529ee709b66625fff6e2a35b78e47f3778f0fe7 --- /dev/null +++ b/tensorflow/contrib/lite/testing/generate_testspec.h @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_GENERATE_TESTSPEC_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_GENERATE_TESTSPEC_H_ + +#include +#include +#include + +namespace tflite { +namespace testing { + +// Generate test spec by executing TensorFlow model on random inputs. +// The test spec can be consumed by ParseAndRunTests. +// See test spec format in parse_testdata.h +// +// Inputs: +// stream: mutable iostream that contains the contents of test spec. +// tensorflow_model_path: path to TensorFlow model. +// tflite_model_path: path to tflite_model_path that the test spec runs +// against. input_layer: names of input tensors. Example: input1 +// input_layer_type: datatypes of input tensors. Example: float +// input_layer_shape: shapes of input tensors, separated by comma. example: +// 1,3,4 output_layer: names of output tensors. Example: output +void GenerateTestSpecFromTensorflowModel( + std::iostream& stream, const string& tensorflow_model_path, + const string& tflite_model_path, const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer); + +// Generates random values that are filled into the tensor. +// random_func returns the generated random element at given index. +template +std::vector GenerateRandomTensor(const std::vector& shape, + const std::function& random_func) { + int64_t num_elements = 1; + for (const int dim : shape) { + num_elements *= dim; + } + + std::vector result(num_elements); + for (int i = 0; i < num_elements; i++) { + result[i] = random_func(i); + } + return result; +} + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_GENERATE_TESTSPEC_H_ diff --git a/tensorflow/contrib/lite/testing/generate_testspec_test.cc b/tensorflow/contrib/lite/testing/generate_testspec_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..2a97b757a413246c9ad9b5f453741b13e381c903 --- /dev/null +++ b/tensorflow/contrib/lite/testing/generate_testspec_test.cc @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/generate_testspec.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +TEST(GenerateRandomTensor, FloatValue) { + static unsigned int seed = 0; + std::function float_rand = [](int idx) { + return static_cast(rand_r(&seed)) / RAND_MAX - 0.5f; + }; + + std::set values; + float sum_x_square = 0.0f; + float sum_x = 0.0f; + for (int i = 0; i < 100; i++) { + const auto& data = GenerateRandomTensor({1, 3, 4}, float_rand); + for (float value : data) { + values.insert(value); + sum_x_square += value * value; + sum_x += value; + } + } + + // Eech round, generated tensor has different values. + EXPECT_GT(values.size(), 200); + int num = 1 * 3 * 4 * 100; + float stddev = sum_x_square / num - (sum_x / num) * (sum_x / num); + + // Stddev is greater than 1/2 stddev of uniform distribution: (B-A)^2 / 12 + float minstddev = 1.0f / 12 / 2; + EXPECT_GT(stddev, minstddev); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 5ea3e21f6a1636d1e7029bed8e75b2f68f656103..86606d12393b94567fbe1fceb6d708b266efe4a8 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -92,6 +92,9 @@ std::map kBrokenTests = { // Transpose only supports 1D-4D input tensors. {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"}, + + // Lstm kernel gets different results on tsan, asan, msan. + {R"(^\/lstmdtype=tf.float32.*)", "73830845"}, }; // Allows test data to be unzipped into a temporary directory and makes @@ -242,6 +245,7 @@ INSTANTIATE_TESTS(constant) INSTANTIATE_TESTS(control_dep) INSTANTIATE_TESTS(conv) INSTANTIATE_TESTS(depthwiseconv) +INSTANTIATE_TESTS(exp) INSTANTIATE_TESTS(fully_connected) INSTANTIATE_TESTS(fused_batch_norm) INSTANTIATE_TESTS(gather) @@ -249,6 +253,7 @@ INSTANTIATE_TESTS(global_batch_norm) INSTANTIATE_TESTS(l2norm) INSTANTIATE_TESTS(l2_pool) INSTANTIATE_TESTS(local_response_norm) +INSTANTIATE_TESTS(log_softmax) INSTANTIATE_TESTS(max_pool) INSTANTIATE_TESTS(mul) INSTANTIATE_TESTS(pad) @@ -261,8 +266,10 @@ INSTANTIATE_TESTS(sigmoid) INSTANTIATE_TESTS(softmax) INSTANTIATE_TESTS(space_to_depth) INSTANTIATE_TESTS(sub) +INSTANTIATE_TESTS(split) INSTANTIATE_TESTS(div) INSTANTIATE_TESTS(transpose) +INSTANTIATE_TESTS(lstm) INSTANTIATE_TESTS(mean) INSTANTIATE_TESTS(squeeze) INSTANTIATE_TESTS(strided_slice) diff --git a/tensorflow/compiler/xla/array2d.cc b/tensorflow/contrib/lite/testing/join.h similarity index 51% rename from tensorflow/compiler/xla/array2d.cc rename to tensorflow/contrib/lite/testing/join.h index 418587c1f75c7249f92e925455d40685d870c57a..ce8c072a21c6e61e8ab8ae12ba52418e6144009a 100644 --- a/tensorflow/compiler/xla/array2d.cc +++ b/tensorflow/contrib/lite/testing/join.h @@ -12,25 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ -#include "tensorflow/compiler/xla/array2d.h" -#include "tensorflow/compiler/xla/ptr_util.h" - -namespace xla { - -std::unique_ptr> MakeLinspaceArray2D(float from, float to, - int64 n1, int64 n2) { - auto array = MakeUnique>(n1, n2); - int64 count = n1 * n2; - float step = (count > 1) ? (to - from) / (count - 1) : 0.0f; - auto set = [&array, n1, n2](int64 index, float value) { - (*array)(index / n2, index % n2) = value; - }; - for (int64 i = 0; i < count - 1; ++i) { - set(i, from + i * step); +#include +#include +#include + +namespace tflite { +namespace testing { + +// Join a list of data separated by delimieter. +template +string Join(T* data, size_t len, const string& delimiter) { + if (len == 0 || data == nullptr) { + return ""; } - set(count - 1, to); - return array; + std::stringstream result; + result << data[0]; + for (int i = 1; i < len; i++) { + result << delimiter << data[i]; + } + return result.str(); } -} // namespace xla +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_JOIN_H_ diff --git a/tensorflow/contrib/lite/testing/join_test.cc b/tensorflow/contrib/lite/testing/join_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd04528381f6d31164728a5cabbf8753e9b8d2b8 --- /dev/null +++ b/tensorflow/contrib/lite/testing/join_test.cc @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/join.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +TEST(JoinTest, JoinInt) { + std::vector data = {1, 2, 3}; + EXPECT_EQ(Join(data.data(), data.size(), ","), "1,2,3"); +} + +TEST(JoinTest, JoinFloat) { + float data[] = {1.0, -3, 2.3, 1e-5}; + EXPECT_EQ(Join(data, 4, " "), "1 -3 2.3 1e-05"); +} + +TEST(JoinTest, JoinNullData) { EXPECT_THAT(Join(nullptr, 3, ","), ""); } + +TEST(JoinTest, JoinZeroData) { + std::vector data; + EXPECT_THAT(Join(data.data(), 0, ","), ""); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/parse_testdata.cc b/tensorflow/contrib/lite/testing/parse_testdata.cc index 0caef0fe2201a668b2235a98304eb353072a3c2f..389688d552051ea735ce71533943af33df5059ef 100644 --- a/tensorflow/contrib/lite/testing/parse_testdata.cc +++ b/tensorflow/contrib/lite/testing/parse_testdata.cc @@ -192,27 +192,25 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, int model_outputs = interpreter->outputs().size(); TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size()); for (size_t i = 0; i < interpreter->outputs().size(); i++) { + bool tensors_differ = false; int output_index = interpreter->outputs()[i]; if (const float* data = interpreter->typed_tensor(output_index)) { for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { float computed = data[idx]; float reference = example.outputs[0].flat_data[idx]; float diff = std::abs(computed - reference); - bool error_is_large = false; // For very small numbers, try absolute error, otherwise go with // relative. - if (std::abs(reference) < kRelativeThreshold) { - error_is_large = (diff > kAbsoluteThreshold); - } else { - error_is_large = (diff > kRelativeThreshold * std::abs(reference)); - } - if (error_is_large) { + bool local_tensors_differ = + std::abs(reference) < kRelativeThreshold + ? diff > kAbsoluteThreshold + : diff > kRelativeThreshold * std::abs(reference); + if (local_tensors_differ) { fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n", i, idx, data[idx], reference); - return kTfLiteError; + tensors_differ = local_tensors_differ; } } - fprintf(stderr, "\n"); } else if (const int32_t* data = interpreter->typed_tensor(output_index)) { for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { @@ -221,10 +219,9 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, if (std::abs(computed - reference) > 0) { fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %d\n", i, idx, computed, reference); - return kTfLiteError; + tensors_differ = true; } } - fprintf(stderr, "\n"); } else if (const int64_t* data = interpreter->typed_tensor(output_index)) { for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) { @@ -235,14 +232,15 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, "output[%zu][%zu] did not match %" PRId64 " vs reference %" PRId64 "\n", i, idx, computed, reference); - return kTfLiteError; + tensors_differ = true; } } - fprintf(stderr, "\n"); } else { fprintf(stderr, "output[%zu] was not float or int data\n", i); return kTfLiteError; } + fprintf(stderr, "\n"); + if (tensors_differ) return kTfLiteError; } return kTfLiteOk; } @@ -319,8 +317,9 @@ class Reshape : public Message { // This is the top-level message in a test file. class TestData : public Message { public: - explicit TestData(TestRunner* test_runner) : test_runner_(test_runner) {} - + explicit TestData(TestRunner* test_runner) + : test_runner_(test_runner), num_invocations_(0), max_invocations_(-1) {} + void SetMaxInvocations(int max) { max_invocations_ = max; } void SetField(const std::string& name, const std::string& value) override { if (name == "load_model") { test_runner_->LoadModel(value); @@ -334,7 +333,12 @@ class TestData : public Message { Message* AddChild(const std::string& s) override { if (s == "invoke") { test_runner_->AllocateTensors(); - return Store(new Invoke(test_runner_)); + if (max_invocations_ == -1 || num_invocations_ < max_invocations_) { + ++num_invocations_; + return Store(new Invoke(test_runner_)); + } else { + return nullptr; + } } else if (s == "reshape") { return Store(new Reshape(test_runner_)); } @@ -343,10 +347,14 @@ class TestData : public Message { private: TestRunner* test_runner_; + int num_invocations_; + int max_invocations_; }; -bool ParseAndRunTests(std::istream* input, TestRunner* test_runner) { +bool ParseAndRunTests(std::istream* input, TestRunner* test_runner, + int max_invocations) { TestData test_data(test_runner); + test_data.SetMaxInvocations(max_invocations); Message::Read(input, &test_data); return test_runner->IsValid() && test_runner->GetOverallSuccess(); } diff --git a/tensorflow/contrib/lite/testing/parse_testdata.h b/tensorflow/contrib/lite/testing/parse_testdata.h index 7ebf362eb99c5f4cf6ea3654cf71e13ff1de99b3..d94361d735e2be8dc130dc8d6bf0bb5c822ebb7c 100644 --- a/tensorflow/contrib/lite/testing/parse_testdata.h +++ b/tensorflow/contrib/lite/testing/parse_testdata.h @@ -66,7 +66,8 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter, const Example&); // output: "12,3,4,545,3" // output: "0.01,0.02" // } -bool ParseAndRunTests(std::istream* input, TestRunner* test_runner); +bool ParseAndRunTests(std::istream* input, TestRunner* test_runner, + int max_invocations = -1); } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/test_runner.h b/tensorflow/contrib/lite/testing/test_runner.h index 60eaafa474a01887bee12b031b1f59cc5c91f173..05770beee23275ebe210606dbfd2b33eea17612d 100644 --- a/tensorflow/contrib/lite/testing/test_runner.h +++ b/tensorflow/contrib/lite/testing/test_runner.h @@ -68,6 +68,10 @@ class TestRunner { // satisfied. virtual bool CheckResults() = 0; + // Read contents of tensor into csv format. + // The given 'id' is guaranteed to be one of the ids returned by GetOutputs(). + virtual string ReadOutput(int id) = 0; + // Set the base path for loading models. void SetModelBaseDir(const string& path) { model_base_dir_ = path; diff --git a/tensorflow/contrib/lite/testing/test_runner_test.cc b/tensorflow/contrib/lite/testing/test_runner_test.cc index f712a5347a042990ae5adb9d44325dd683193168..3f04aa20bd7de813f0acd3f5897d5ab2df6c0fd7 100644 --- a/tensorflow/contrib/lite/testing/test_runner_test.cc +++ b/tensorflow/contrib/lite/testing/test_runner_test.cc @@ -31,6 +31,7 @@ class ConcreteTestRunner : public TestRunner { void ResetTensor(int id) override {} void SetInput(int id, const string& csv_values) override {} void SetExpectation(int id, const string& csv_values) override {} + string ReadOutput(int id) override { return ""; } void Invoke() override {} bool CheckResults() override { return true; } bool CheckFloatSizes(size_t bytes, size_t values) { diff --git a/tensorflow/contrib/lite/testing/tf_driver.cc b/tensorflow/contrib/lite/testing/tf_driver.cc new file mode 100644 index 0000000000000000000000000000000000000000..2c253bb1983e5ddc5bc12858c929585d1bcee710 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tf_driver.cc @@ -0,0 +1,182 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tf_driver.h" + +#include +#include + +#include "tensorflow/contrib/lite/testing/join.h" +#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tflite { +namespace testing { + +namespace { + +tensorflow::Tensor CreateTensor(const tensorflow::DataType type, + const std::vector& dim) { + tensorflow::TensorShape shape{gtl::ArraySlice{ + reinterpret_cast(dim.data()), dim.size()}}; + return {type, shape}; +} + +template +void FillTensorWithData(tensorflow::Tensor* tensor, const string& csv_values) { + auto data = tensor->flat(); + + const auto& values = testing::Split(csv_values, ","); + for (int i = 0; i < values.size(); i++) { + data(i) = values[i]; + } +} + +template +void FillTensorWithZeros(tensorflow::Tensor* tensor) { + auto data = tensor->flat(); + for (int i = 0; i < tensor->NumElements(); i++) { + data(i) = 0; + } +} + +template +string TensorDataToCsvString(const tensorflow::Tensor& tensor) { + const auto& data = tensor.flat(); + return Join(data.data(), data.size(), ","); +} + +} // namespace + +TfDriver::TfDriver(const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer) + : input_names_(input_layer), output_names_(output_layer) { + CHECK_EQ(input_layer.size(), input_layer_type.size()); + CHECK_EQ(input_layer.size(), input_layer_shape.size()); + + input_ids_.resize(input_layer.size()); + input_tensors_.reserve(input_layer.size()); + input_types_.resize(input_layer.size()); + input_shapes_.resize(input_layer.size()); + for (int i = 0; i < input_layer.size(); i++) { + input_ids_[i] = i; + input_tensors_[input_layer[i]] = {}; + CHECK(DataTypeFromString(input_layer_type[i], &input_types_[i])); + input_shapes_[i] = Split(input_layer_shape[i], ","); + } + + output_ids_.resize(output_layer.size()); + output_tensors_.reserve(output_layer.size()); + for (int i = 0; i < output_layer.size(); i++) { + output_ids_[i] = i; + } +} + +void TfDriver::LoadModel(const string& bin_file_path) { + if (!IsValid()) return; + std::cout << std::endl << "Loading model: " << bin_file_path << std::endl; + std::ifstream model(bin_file_path); + if (model.fail()) { + Invalidate("Failed to find the model"); + return; + } + + tensorflow::GraphDef graphdef; + if (!graphdef.ParseFromIstream(&model)) { + Invalidate("Failed to parse tensorflow graphdef"); + return; + } + + tensorflow::SessionOptions options; + session_.reset(tensorflow::NewSession(options)); + auto status = session_->Create(graphdef); + if (!status.ok()) { + Invalidate("Failed to create session"); + } +} + +void TfDriver::SetInput(int id, const string& csv_values) { + if (!IsValid()) return; + + auto tensor = CreateTensor(input_types_[id], input_shapes_[id]); + switch (input_types_[id]) { + case tensorflow::DT_FLOAT: { + FillTensorWithData(&tensor, csv_values); + break; + } + case tensorflow::DT_INT32: { + FillTensorWithData(&tensor, csv_values); + break; + } + default: + fprintf(stderr, "Unsupported type %d in SetInput\n", input_types_[id]); + Invalidate("Unsupported tensor data type"); + return; + } + input_tensors_[input_names_[id]] = tensor; +} + +void TfDriver::ResetTensor(int id) { + if (!IsValid()) return; + auto tensor = input_tensors_[input_names_[id]]; + switch (input_types_[id]) { + case tensorflow::DT_FLOAT: { + FillTensorWithZeros(&tensor); + break; + } + case tensorflow::DT_INT32: { + FillTensorWithZeros(&tensor); + break; + } + default: + fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]); + Invalidate("Unsupported tensor data type"); + return; + } +} + +void TfDriver::ReshapeTensor(int id, const string& csv_values) { + input_shapes_[id] = Split(csv_values, ","); + input_tensors_[input_names_[id]] = + CreateTensor(input_types_[id], input_shapes_[id]); + ResetTensor(id); +} + +string TfDriver::ReadOutput(int id) { + if (!IsValid()) return ""; + switch (output_tensors_[id].dtype()) { + case tensorflow::DT_FLOAT: + return TensorDataToCsvString(output_tensors_[id]); + case tensorflow::DT_INT32: + return TensorDataToCsvString(output_tensors_[id]); + default: + fprintf(stderr, "Unsupported type %d in ResetTensor\n", input_types_[id]); + Invalidate("Unsupported tensor data type"); + return ""; + } +} + +void TfDriver::Invoke() { + if (!IsValid()) return; + auto status = session_->Run({input_tensors_.begin(), input_tensors_.end()}, + output_names_, {}, &output_tensors_); + if (!status.ok()) { + Invalidate("Failed to invoke interpreter"); + } +} + +} // namespace testing +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tf_driver.h b/tensorflow/contrib/lite/testing/tf_driver.h new file mode 100644 index 0000000000000000000000000000000000000000..b766f85c4ddee9fb7b1513c264d4159e694770ca --- /dev/null +++ b/tensorflow/contrib/lite/testing/tf_driver.h @@ -0,0 +1,75 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ + +#include +#include + +#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/contrib/lite/testing/test_runner.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session.h" + +namespace tflite { +namespace testing { + +// A test runner that feeds inputs into Tensorflow and generates outputs. +class TfDriver : public TestRunner { + public: + explicit TfDriver(const std::vector& input_layer, + const std::vector& input_layer_type, + const std::vector& input_layer_shape, + const std::vector& output_layer); + ~TfDriver() override {} + + void LoadModel(const string& bin_file_path) override; + void SetInput(int id, const string& csv_values) override; + void Invoke() override; + string ReadOutput(int id) override; + + const std::vector& GetInputs() override { return input_ids_; } + const std::vector& GetOutputs() override { return output_ids_; } + void ReshapeTensor(int id, const string& csv_values) override; + // Note: ResetTensor only works for input tensor. + void ResetTensor(int id) override; + + // no-op. SetInput will overwrite existing data . + void AllocateTensors() override {} + // no-op. Tf driver is not supposed to check the results. + void SetExpectation(int id, const string& csv_values) override {} + // tf driver is not supposed to check the results. + bool CheckResults() override { return false; } + + private: + std::unique_ptr session_; + std::vector input_ids_; + std::vector input_names_; + std::vector> input_shapes_; + std::vector input_types_; + std::unordered_map input_tensors_; + + std::vector output_ids_; + std::vector output_names_; + std::vector<::tensorflow::Tensor> output_tensors_; +}; + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TF_DRIVER_H_ diff --git a/tensorflow/contrib/lite/testing/tf_driver_test.cc b/tensorflow/contrib/lite/testing/tf_driver_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c0faa4676adc3e846ad398bb203b77b99a2ba360 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tf_driver_test.cc @@ -0,0 +1,56 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/testing/tf_driver.h" + +#include +#include + +namespace tflite { +namespace testing { +namespace { + +using ::testing::ElementsAre; + +TEST(TfDriverTest, SimpleTest) { + std::unique_ptr runner( + new TfDriver({"a", "b", "c", "d"}, {"float", "float", "float", "float"}, + {"1,8,8,3", "1,8,8,3", "1,8,8,3", "1,8,8,3"}, {"x", "y"})); + + runner->LoadModel( + "third_party/tensorflow/contrib/lite/testdata/multi_add.pb"); + EXPECT_TRUE(runner->IsValid()) << runner->GetErrorMessage(); + + ASSERT_THAT(runner->GetInputs(), ElementsAre(0, 1, 2, 3)); + ASSERT_THAT(runner->GetOutputs(), ElementsAre(0, 1)); + + for (int i : {0, 1, 2, 3}) { + runner->ReshapeTensor(i, "1,2,2,1"); + } + ASSERT_TRUE(runner->IsValid()); + + runner->SetInput(0, "0.1,0.2,0.3,0.4"); + runner->SetInput(1, "0.001,0.002,0.003,0.004"); + runner->SetInput(2, "0.001,0.002,0.003,0.004"); + runner->SetInput(3, "0.01,0.02,0.03,0.04"); + runner->ResetTensor(2); + runner->Invoke(); + + ASSERT_EQ(runner->ReadOutput(0), "0.101,0.202,0.303,0.404"); + ASSERT_EQ(runner->ReadOutput(1), "0.011,0.022,0.033,0.044"); +} + +} // namespace +} // namespace testing +} // namespace tflite diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc similarity index 56% rename from tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc rename to tensorflow/contrib/lite/testing/tflite_diff_example_test.cc index 62bb87f2b03dceb7d6f73df6eef26f5e4b31607f..3817e68111dbaaf2a38ceff9fbc38f30f303cb5f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.cc +++ b/tensorflow/contrib/lite/testing/tflite_diff_example_test.cc @@ -13,25 +13,16 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/Eigen/Core" - -#ifdef TF_XLA_HAS_AVX -xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX( - xla::cpu::runtime::V8F32AVX x) { - return Eigen::internal::plog(x); +#include "tensorflow/contrib/lite/testing/tflite_diff_flags.h" +#include "tensorflow/contrib/lite/testing/tflite_diff_util.h" + +int main(int argc, char** argv) { + ::tflite::testing::DiffOptions options = + ::tflite::testing::ParseTfliteDiffFlags(&argc, argv); + for (int i = 0; i < 100; i++) { + if (!tflite::testing::RunDiffTest(options)) { + return 1; + } + } + return 0; } -#endif // TF_XLA_HAS_AVX - -namespace xla { -namespace cpu { -namespace runtime { - -const char *const kLogV8F32AVXSymbolName = "__xla_cpu_runtime_LogV8F32AVX"; - -} // namespace runtime -} // namespace cpu -} // namespace xla diff --git a/tensorflow/contrib/lite/testing/tflite_diff_flags.h b/tensorflow/contrib/lite/testing/tflite_diff_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..5f1129d501b7235f1202b704cf36904e07b8720e --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_diff_flags.h @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ + +#include "tensorflow/contrib/lite/testing/split.h" +#include "tensorflow/contrib/lite/testing/tflite_diff_util.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tflite { +namespace testing { + +DiffOptions ParseTfliteDiffFlags(int* argc, char** argv) { + struct { + string tensorflow_model; + string tflite_model; + string input_layer; + string input_layer_type; + string input_layer_shape; + string output_layer; + } values; + + std::vector flags = { + tensorflow::Flag("tensorflow_model", &values.tensorflow_model, + "Path of tensorflow model."), + tensorflow::Flag("tflite_model", &values.tflite_model, + "Path of tensorflow lite model."), + tensorflow::Flag("input_layer", &values.input_layer, + "Names of input tensors, separated by comma. Example: " + "input_1,input_2"), + tensorflow::Flag("input_layer_type", &values.input_layer_type, + "Data types of input tensors, separated by comma. " + "Example: float,int"), + tensorflow::Flag( + "input_layer_shape", &values.input_layer_shape, + "Shapes of input tensors, separated by colon. Example: 1,3,4,1:2"), + tensorflow::Flag("output_layer", &values.output_layer, + "Names of output tensors, separated by comma. Example " + "output_1,output_2"), + }; + + bool success = tensorflow::Flags::Parse(argc, argv, flags); + if (!success || (*argc == 2 && !strcmp(argv[1], "--helpfull"))) { + fprintf(stderr, "%s", tensorflow::Flags::Usage(argv[0], flags).c_str()); + } + + return {values.tensorflow_model, + values.tflite_model, + Split(values.input_layer, ","), + Split(values.input_layer_type, ","), + Split(values.input_layer_shape, ":"), + Split(values.output_layer, ",")}; +} + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_FLAGS_H_ diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.cc b/tensorflow/contrib/lite/testing/tflite_diff_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..9ef4e1f66c7d31c746c18d63495e760585d4af9e --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_diff_util.cc @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/testing/generate_testspec.h" +#include "tensorflow/contrib/lite/testing/parse_testdata.h" +#include "tensorflow/contrib/lite/testing/tflite_diff_util.h" +#include "tensorflow/contrib/lite/testing/tflite_driver.h" + +namespace tflite { +namespace testing { + +bool RunDiffTest(const DiffOptions& options) { + std::stringstream tflite_stream; + GenerateTestSpecFromTensorflowModel( + tflite_stream, options.tensorflow_model, options.tflite_model, + options.input_layer, options.input_layer_type, options.input_layer_shape, + options.output_layer); + TfLiteDriver tflite_driver(/*use_nnapi=*/true); + tflite_driver.LoadModel(options.tflite_model); + std::cout << tflite_stream.str(); + return tflite::testing::ParseAndRunTests(&tflite_stream, &tflite_driver); +} +} // namespace testing + +} // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_diff_util.h b/tensorflow/contrib/lite/testing/tflite_diff_util.h new file mode 100644 index 0000000000000000000000000000000000000000..326fa6c3e28000dee9b6eb9cc5b3a6c5c87e28d0 --- /dev/null +++ b/tensorflow/contrib/lite/testing/tflite_diff_util.h @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_UTIL_H_ + +#include + +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace testing { + +// Configurations to run Tflite diff test. +struct DiffOptions { + // Path of tensorflow model. + string tensorflow_model; + // Path of tensorflow lite model. + string tflite_model; + // Names of input tensors. + // Example: input_1,input_2 + std::vector input_layer; + // Data types of input tensors. + // Example: float,int + std::vector input_layer_type; + // Shapes of input tensors, separated by comma. + // Example: 1,3,4,1 + std::vector input_layer_shape; + // Names of output tensors. + // Example output_1,output_2 + std::vector output_layer; +}; + +// Run a single TensorFLow Lite diff test with a given options. +bool RunDiffTest(const DiffOptions& options); + +} // namespace testing +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_TESTING_TFLITE_DIFF_UTIL_H_ diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index bae639ea95318a16c963269de5e55afcb681d4c5..613223f3d4ff212cb8672494243b2d7a1d06b3db 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -106,8 +106,8 @@ class TfLiteDriver::Expectation { if (error_is_large) { good_output = false; if (verbose) { - std::cerr << " index " << i << ": " << reference - << " != " << computed << std::endl; + std::cerr << " index " << i << ": got " << computed + << ", but expected " << reference << std::endl; } } } @@ -203,6 +203,10 @@ void TfLiteDriver::SetInput(int id, const string& csv_values) { void TfLiteDriver::SetExpectation(int id, const string& csv_values) { if (!IsValid()) return; auto* tensor = interpreter_->tensor(id); + if (expected_output_.count(id) != 0) { + fprintf(stderr, "Overriden expectation for tensor %d\n", id); + Invalidate("Overriden expectation"); + } expected_output_[id].reset(new Expectation); switch (tensor->type) { case kTfLiteFloat32: diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h index 25689a9fb42c06fa3f8f2f92064cf59e8c331637..02b7de1534e648734d7bc53154afa42f2ef256b4 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.h +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -45,6 +45,7 @@ class TfLiteDriver : public TestRunner { void SetExpectation(int id, const string& csv_values) override; void Invoke() override; bool CheckResults() override; + string ReadOutput(int id) override { return "no-op"; } private: class Expectation; diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 45031de09c75e9dbf5ee34fe31e7c69ad08b10aa..17407f3db27ead984d1cfffc3f0085ac86f5318f 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -186,6 +186,7 @@ cc_library( "graph_transformations/fuse_binary_into_preceding_affine.cc", "graph_transformations/graph_transformations.cc", "graph_transformations/hardcode_min_max.cc", + "graph_transformations/identify_dilated_conv.cc", "graph_transformations/identify_l2_normalization.cc", "graph_transformations/identify_l2_pool.cc", "graph_transformations/identify_lstm.cc", @@ -224,6 +225,7 @@ cc_library( "graph_transformations/resolve_constant_transpose.cc", "graph_transformations/resolve_constant_unary.cc", "graph_transformations/resolve_mean_attributes.cc", + "graph_transformations/resolve_multiply_by_zero.cc", "graph_transformations/resolve_pad_attributes.cc", "graph_transformations/resolve_reorder_axes.cc", "graph_transformations/resolve_reshape_attributes.cc", diff --git a/tensorflow/contrib/lite/toco/args.h b/tensorflow/contrib/lite/toco/args.h index b97a4720a7c4e69f8b69574475d19e0522cfe86d..59a6115920614d38900c0370708324c122384420 100644 --- a/tensorflow/contrib/lite/toco/args.h +++ b/tensorflow/contrib/lite/toco/args.h @@ -229,6 +229,7 @@ struct ParsedTocoFlags { // Deprecated flags Arg input_type; Arg input_types; + Arg debug_disable_recurrent_cell_fusion = Arg(false); Arg drop_control_dependency = Arg(false); }; diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc index c726eb6d8678e2703f5acba8b3d8d740186939f5..c8352741b44cd627ff9edb9c4677b994c4cb9a09 100644 --- a/tensorflow/contrib/lite/toco/dump_graphviz.cc +++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc @@ -142,14 +142,8 @@ NodeProperties GetPropertiesForArray(const Model& model, // Append array shape to the label. auto& array = model.GetArray(array_name); - - if (array.data_type == ArrayDataType::kFloat) { - AppendF(&node_properties.label, "\\nType: float"); - } else if (array.data_type == ArrayDataType::kInt32) { - AppendF(&node_properties.label, "\\nType: int32"); - } else if (array.data_type == ArrayDataType::kUint8) { - AppendF(&node_properties.label, "\\nType: uint8"); - } + AppendF(&node_properties.label, "\\nType: %s", + ArrayDataTypeName(array.data_type)); if (array.has_shape()) { auto& array_shape = array.shape(); @@ -199,12 +193,12 @@ NodeProperties GetPropertiesForArray(const Model& model, } if (array.minmax) { - AppendF(&node_properties.label, "\\nMinMax: [%.3g, %.3g]", + AppendF(&node_properties.label, "\\nMinMax: [%.7g, %.7g]", array.minmax->min, array.minmax->max); } if (array.quantization_params) { - AppendF(&node_properties.label, "\\nQuantization: %.3g * (x - %d)", + AppendF(&node_properties.label, "\\nQuantization: %7g * (x - %d)", array.quantization_params->scale, array.quantization_params->zero_point); } diff --git a/tensorflow/contrib/lite/toco/export_tensorflow.cc b/tensorflow/contrib/lite/toco/export_tensorflow.cc index 70d7a9d4a5b823d6b4f704b194aa279fbd3f6e13..6900468ec6484d5c1896752286a2fa72f4d38c07 100644 --- a/tensorflow/contrib/lite/toco/export_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/export_tensorflow.cc @@ -239,6 +239,7 @@ void ConvertIntTensorConst(const Model& model, const string& name, } void CreateIntTensorConst(const string& name, const std::vector& data, + const std::vector& shape, GraphDef* tensorflow_graph) { if (HasAlreadyExportedConst(name, *tensorflow_graph)) { return; @@ -252,8 +253,13 @@ void CreateIntTensorConst(const string& name, const std::vector& data, for (auto index : data) { tensor->add_int_val(index); } - auto* shape = tensor->mutable_tensor_shape(); - shape->add_dim()->set_size(data.size()); + auto* tensor_shape = tensor->mutable_tensor_shape(); + int num_elements = 1; + for (int size : shape) { + tensor_shape->add_dim()->set_size(size); + num_elements *= size; + } + CHECK_EQ(num_elements, data.size()); } void CreateMatrixShapeTensorConst(const string& name, int rows, int cols, @@ -385,6 +391,84 @@ void ConvertConvOperator(const Model& model, const ConvOperator& src_op, } } +void ConvertDilatedConvOperator(const Model& model, const ConvOperator& src_op, + GraphDef* tensorflow_graph) { + CHECK((src_op.dilation_width_factor > 1) || + (src_op.dilation_height_factor > 1)) + << "Conv operator must have height or width dilation factor > 1. " + "Otherwise, use regular conv op."; + CHECK_EQ(src_op.stride_width, 1) + << "Dilated AND strided convolution is unsupported"; + CHECK_EQ(src_op.stride_height, 1) + << "Dilated AND strided convolution is unsupported"; + + // Emulate dilated convolution with a chain of SpaceToBatchND -> Conv -> + // BatchToSpaceND ops. + + // Compute padding + const auto& input_array = model.GetArray(src_op.inputs[0]); + const auto& input_shape = input_array.shape(); + CHECK_EQ(input_shape.dimensions_count(), 4); + int height_mod_dilation = input_shape.dims(1) % src_op.dilation_height_factor; + int pad_height; + if (height_mod_dilation) { + pad_height = src_op.dilation_height_factor - height_mod_dilation; + } else { + pad_height = 0; + } + int pad_width; + int width_mod_dilation = input_shape.dims(2) % src_op.dilation_width_factor; + if (width_mod_dilation) { + pad_width = src_op.dilation_width_factor - width_mod_dilation; + } else { + pad_width = 0; + } + + // SpaceToBatchND op "collapses" the spatially separated elements together + string stb_output = src_op.outputs[0] + "/dilated_conv_SpaceToBatch"; + auto* stb_op = tensorflow_graph->add_node(); + stb_op->set_op("SpaceToBatchND"); + stb_op->set_name(stb_output); + *stb_op->add_input() = src_op.inputs[0]; + (*stb_op->mutable_attr())["T"].set_type(DT_FLOAT); + string block_shape = src_op.outputs[0] + "/dilated_conv_block_shape"; + CreateIntTensorConst( + block_shape, + {src_op.dilation_height_factor, src_op.dilation_width_factor}, {2}, + tensorflow_graph); + *stb_op->add_input() = block_shape; + (*stb_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); + string stb_paddings = src_op.outputs[0] + "/dilated_conv_paddings"; + CreateIntTensorConst(stb_paddings, {0, pad_height, pad_width, 0}, {2, 2}, + tensorflow_graph); + *stb_op->add_input() = stb_paddings; + (*stb_op->mutable_attr())["Tpaddings"].set_type(DT_INT32); + + // Perform a regular conv on the "collapsed" elements + ConvOperator conv_op; + string conv_output = src_op.outputs[0] + "/dilated_conv_Conv2D"; + conv_op.inputs = src_op.inputs; + conv_op.inputs[0] = stb_output; + conv_op.outputs = {conv_output}; + conv_op.padding.type = src_op.padding.type; + conv_op.stride_width = src_op.stride_width; + conv_op.stride_height = src_op.stride_height; + conv_op.dilation_width_factor = 1; + conv_op.dilation_height_factor = 1; + ConvertConvOperator(model, conv_op, tensorflow_graph); + + // BatchToSpaceND op restores elements to their original layout + auto* bts_op = tensorflow_graph->add_node(); + bts_op->set_op("BatchToSpaceND"); + bts_op->set_name(src_op.outputs[0]); + *bts_op->add_input() = conv_output; + (*bts_op->mutable_attr())["T"].set_type(DT_FLOAT); + *bts_op->add_input() = block_shape; + (*bts_op->mutable_attr())["Tblock_shape"].set_type(DT_INT32); + *bts_op->add_input() = stb_paddings; + (*bts_op->mutable_attr())["Tcrops"].set_type(DT_INT32); +} + void ConvertDepthwiseConvOperator(const Model& model, const DepthwiseConvOperator& src_op, GraphDef* tensorflow_graph) { @@ -520,7 +604,7 @@ void ConvertFullyConnectedOperator(const Model& model, AvailableArrayName(model, matmul_output + "/transpose_weights"); const string transpose_perm = AvailableArrayName(model, transpose_output + "/perm"); - CreateIntTensorConst(transpose_perm, {1, 0}, tensorflow_graph); + CreateIntTensorConst(transpose_perm, {1, 0}, {2}, tensorflow_graph); auto transpose_op = tensorflow_graph->add_node(); transpose_op->set_op("Transpose"); transpose_op->set_name(transpose_output); @@ -720,7 +804,8 @@ void ConvertLogSoftmaxOperator(const Model& model, GraphDef* tensorflow_graph) { string softmax_input; Operator* providing_op = GetOpWithOutput(model, src_op.inputs[0]); - if (providing_op->type == OperatorType::kTensorFlowReshape) { + if (providing_op != nullptr && + providing_op->type == OperatorType::kTensorFlowReshape) { softmax_input = src_op.inputs[0]; } else { // Insert a reshape operator that reduces the dimensions down to the 2 that @@ -1236,8 +1321,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Write weights const string weights_output = base + "weights"; CHECK(model.HasArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT])); - const auto& weights_array = - model.GetArray(src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]); + const string weights_name = WalkUpToConstantArray( + model, src_op.inputs[LstmCellOperator::WEIGHTS_INPUT]); + const auto& weights_array = model.GetArray(weights_name); // Convert 4D FullyConnected weights into 2D matrix const auto& weights_shape = weights_array.shape(); CHECK_EQ(weights_shape.dimensions_count(), 2); @@ -1262,8 +1348,9 @@ void ConvertLstmCellOperator(const Model& model, const LstmCellOperator& src_op, // Write biases const string biases_output = base + "biases"; CHECK(model.HasArray(src_op.inputs[LstmCellOperator::BIASES_INPUT])); - const auto& bias_array = - model.GetArray(src_op.inputs[LstmCellOperator::BIASES_INPUT]); + const string bias_name = WalkUpToConstantArray( + model, src_op.inputs[LstmCellOperator::BIASES_INPUT]); + const auto& bias_array = model.GetArray(bias_name); // TODO(b/62904716) Bias arrays should be 1-D, and used directly. Shape bias_shape_1d = bias_array.shape(); UnextendShape(&bias_shape_1d, 1); @@ -1579,6 +1666,17 @@ void ConvertTensorFlowMaximumOperator(const Model& model, (*sub_op->mutable_attr())["T"].set_type(data_type); } +void ConvertTopKV2Operator(const Model& model, const TopKV2Operator& src_op, + GraphDef* tensorflow_graph) { + auto* topk_op = tensorflow_graph->add_node(); + topk_op->set_op("TOPKV2"); + topk_op->set_name(src_op.outputs[0]); + CHECK_EQ(src_op.inputs.size(), 2); + *topk_op->add_input() = src_op.inputs[0]; + *topk_op->add_input() = src_op.inputs[1]; + (*topk_op->mutable_attr())["sorted"].set_b(true); +} + void ConvertOperator(const Model& model, const Operator& src_op, GraphDef* tensorflow_graph) { if (src_op.fused_activation_function != FusedActivationFunctionType::kNone) { @@ -1587,8 +1685,13 @@ void ConvertOperator(const Model& model, const Operator& src_op, } if (src_op.type == OperatorType::kConv) { - ConvertConvOperator(model, static_cast(src_op), - tensorflow_graph); + const ConvOperator& conv_op = static_cast(src_op); + if ((conv_op.dilation_width_factor != 1) || + (conv_op.dilation_height_factor != 1)) { + return ConvertDilatedConvOperator(model, conv_op, tensorflow_graph); + } else { + ConvertConvOperator(model, conv_op, tensorflow_graph); + } } else if (src_op.type == OperatorType::kDepthwiseConv) { ConvertDepthwiseConvOperator( model, static_cast(src_op), @@ -1727,6 +1830,9 @@ void ConvertOperator(const Model& model, const Operator& src_op, } else if (src_op.type == OperatorType::kArgMax) { ConvertArgMaxOperator(model, static_cast(src_op), tensorflow_graph); + } else if (src_op.type == OperatorType::kTopK_V2) { + ConvertTopKV2Operator(model, static_cast(src_op), + tensorflow_graph); } else if (src_op.type == OperatorType::kTranspose) { ConvertTransposeOperator( model, static_cast(src_op), tensorflow_graph); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 3ab01ae643b26cfe0c7ce30472f693794326b9b3..f2c81ebc81c2928ae60d66bfcd7f643c5412f196 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -128,6 +128,7 @@ DECLARE_GRAPH_TRANSFORMATION(IdentifyLstmCell) DECLARE_GRAPH_TRANSFORMATION(SplitLstmCellInputs) DECLARE_GRAPH_TRANSFORMATION(MergeLstmCellInputs) DECLARE_GRAPH_TRANSFORMATION(IdentifyRelu1) +DECLARE_GRAPH_TRANSFORMATION(IdentifyDilatedConv) DECLARE_GRAPH_TRANSFORMATION(MakeInitialDequantizeOperator) DECLARE_GRAPH_TRANSFORMATION(PropagateArrayDataTypes) DECLARE_GRAPH_TRANSFORMATION(PropagateFixedSizes) @@ -174,6 +175,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) +DECLARE_GRAPH_TRANSFORMATION(ResolveMultiplyByZero) DECLARE_GRAPH_TRANSFORMATION(Dequantize) class ResolveReshapeAttributes : public GraphTransformation { diff --git a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc index 1b0be858107b54f5a6ecd2a1cb87c9dbde1c06bb..938d76386d6f315abfe6fe55b133cb4d19014f01 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/hardcode_min_max.cc @@ -125,6 +125,27 @@ bool HardcodeMinMaxForConcatenation(Model* model, Operator* op) { return changed; } +bool HardcodeMinMaxForSplit(Model* model, Operator* op) { + for (const auto& output : op->outputs) { + if (model->GetArray(output).minmax) { + LOG(WARNING) << "Skipping min-max setting for " << LogName(*op) + << " because output " << output << " already has min-max."; + return false; + } + } + // Data is in second input. + auto& input_array = model->GetArray(op->inputs[1]); + if (!input_array.minmax) { + return false; + } else { + for (const auto& output : op->outputs) { + auto& array = model->GetArray(output); + array.GetOrCreateMinMax() = *input_array.minmax; + } + return true; + } +} + // The output of average or max pooling is within the same range as its input. bool HardcodeMinMaxForAverageOrMaxPool(Model* model, Operator* op) { auto& output_array = model->GetArray(op->outputs[0]); @@ -296,6 +317,10 @@ bool HardcodeMinMax::Run(Model* model, std::size_t op_index) { changed = HardcodeMinMaxForConcatenation(model, op); break; + case OperatorType::kTensorFlowSplit: + changed = HardcodeMinMaxForSplit(model, op); + break; + case OperatorType::kAveragePool: case OperatorType::kMaxPool: changed = HardcodeMinMaxForAverageOrMaxPool(model, op); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae3301f467de5714230e731b4bab87ddc1637201 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_dilated_conv.cc @@ -0,0 +1,213 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +// A dilated convolution can be emulated with a regular convolution by chaining +// SpaceToBatch and BatchToSpace ops before and after it: +// +// SpaceToBatchND -> Conv2D -> BatchToSpaceND +// +// This method was common before Conv2D fully supported dilated convolution in +// TensorFlow. This transformation detects this "emulation", and replaces it +// with a true dilated convolution, eliminating the SpaceToBatch and +// BatchtoSpace ops. +// +// Detecting this alone would be relatively easy. However, in practice some +// extra ops are used, so we detect the following patterns: +// +// +// SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BatchToSpaceND -> BiasAdd +// +// SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> Pad -> BatchToSpaceND -> +// BiasAdd +// +// SpaceToBatchND -> Expand -> Conv2D -> Squeeze -> BiasAdd -> BatchToSpaceND +// +// SpaceToBatchND -> Conv2D -> Pad -> BatchToSpaceND -> BiasAdd +// +// SpaceToBatchND -> Conv2D -> BatchToSpaceND -> BiasAdd +// +// +// The Expand/Squeeze combination is used to adapt a 3D array (such as in +// WaveNet) to the 4D arrays that Conv2D requires. Padding and BiasAdd are +// thrown in just for the extra headache. Padding adapts non-conforming input +// sizes, and can be discarded. The bias is necessary, so is kept. + +bool IdentifyDilatedConv::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + auto* stb_op = it->get(); + + // 1. IDENTIFY OPERATORS + // *************************************************************************** + // SpaceToBatch Op. + if (stb_op->type != OperatorType::kSpaceToBatchND) { + return false; + } + if (stb_op->inputs.size() != 3) { + return false; + } + CHECK_EQ(stb_op->outputs.size(), 1); + // Extract the dilation factor from Input[1] of SpaceToBatch + // TODO(mjmatthews): Support 2D dilation factors. + const auto& block_shape_array = model->GetArray(stb_op->inputs[1]); + if (!block_shape_array.buffer) { + return false; + } + CHECK_EQ(block_shape_array.shape().dimensions_count(), 1); + int dilation_factor = + block_shape_array.Array::GetBuffer().data[0]; + + // Expand Op + auto* post_stb_op = GetOpWithInput(*model, stb_op->outputs[0]); + if (!post_stb_op) { + return false; + } + bool has_expand_op = false; + if (post_stb_op->type == OperatorType::kExpandDims) { + has_expand_op = true; + CHECK_EQ(post_stb_op->inputs.size(), 2); + CHECK_EQ(post_stb_op->outputs.size(), 1); + } + + // Conv Op + ConvOperator* conv_op = dynamic_cast( + has_expand_op ? GetOpWithInput(*model, post_stb_op->outputs[0]) + : GetOpWithInput(*model, stb_op->outputs[0])); + if (!conv_op || conv_op->type != OperatorType::kConv) { + return false; + } + if (conv_op->inputs.size() != 2) { + // The conv op must only have weights, no bias. + return false; + } + CHECK_EQ(conv_op->outputs.size(), 1); + + // Squeeze Op + auto* post_conv_op = GetOpWithInput(*model, conv_op->outputs[0]); + if (!post_conv_op) { + return false; + } + if (has_expand_op) { + if (post_conv_op->type != OperatorType::kSqueeze) { + // If an expand op was used, the post-conv op must be a squeeze op + return false; + } + CHECK_EQ(post_conv_op->inputs.size(), 1); + CHECK_EQ(post_conv_op->outputs.size(), 1); + } + + // Pad Op + const auto* pad_op = has_expand_op + ? GetOpWithInput(*model, post_conv_op->outputs[0]) + : GetOpWithInput(*model, conv_op->outputs[0]); + bool has_pad_op = false; + if (pad_op->type == OperatorType::kPad) { + has_pad_op = true; + CHECK_EQ(pad_op->inputs.size(), 2); + CHECK_EQ(pad_op->outputs.size(), 1); + } + // TODO(mjmatthews): Perform validity checking on padding dimensions. + + // Pre-BatchToSpace Bias Op + auto* next_op = has_pad_op + ? GetOpWithInput(*model, pad_op->outputs[0]) + : has_expand_op + ? GetOpWithInput(*model, post_conv_op->outputs[0]) + : GetOpWithInput(*model, conv_op->outputs[0]); + bool has_bias_before_bts = false; + if (next_op->type == OperatorType::kAdd) { + has_bias_before_bts = true; + } + auto final_op = GetOpWithInput(*model, next_op->outputs[0]); + + // BatchToSpace Op + const auto* bts_op = has_bias_before_bts ? final_op : next_op; + if (bts_op->type != OperatorType::kBatchToSpaceND) { + return false; + } + CHECK_EQ(bts_op->inputs.size(), 3); + CHECK_EQ(bts_op->outputs.size(), 1); + + // Post-BatchToSpace Bias Op + Operator* bias_add_op = !has_bias_before_bts ? final_op : next_op; + if (bias_add_op->type != OperatorType::kAdd) { + // Bias op is required before or after BatchToSpace + return false; + } + CHECK_EQ(bias_add_op->inputs.size(), 2); + CHECK_EQ(bias_add_op->outputs.size(), 1); + + LOG(INFO) << "Identified sub-network emulating dilated convolution."; + + // 2. RE-WIRE OPERATORS + // *************************************************************************** + // Re-use the existing Conv2D op. + conv_op->dilation_width_factor = dilation_factor; + conv_op->dilation_height_factor = dilation_factor; + conv_op->padding.type = PaddingType::kSame; + + // Rewire the ops to bypass SpaceToBatch, BatchToSpace, and Pad. + bias_add_op->outputs[0] = final_op->outputs[0]; + if (has_expand_op) { + bias_add_op->inputs[0] = post_conv_op->outputs[0]; + post_conv_op->inputs[0] = conv_op->outputs[0]; + conv_op->inputs[0] = post_stb_op->outputs[0]; + post_stb_op->inputs[0] = stb_op->inputs[0]; + } else { + bias_add_op->inputs[0] = conv_op->outputs[0]; + conv_op->inputs[0] = stb_op->inputs[0]; + } + // TODO(mjmatthews): Connect bias directly into the Conv2D? + + // 3. DELETE LEFTOVER OPERATORS + // *************************************************************************** + // Order is important. Delete the output array first, then the op, then it's + // redundant inputs. + // BatchToSpace Op + DeleteArrayIfUnused(bts_op->outputs[0], model); + std::vector bts_op_inputs = bts_op->inputs; + model->operators.erase(FindOp(*model, bts_op)); + DeleteArrayIfUnused(bts_op_inputs[1], model); + DeleteArrayIfUnused(bts_op_inputs[2], model); + + // Pad Op if present + if (has_pad_op) { + DeleteArrayIfUnused(pad_op->outputs[0], model); + std::vector pad_op_inputs = pad_op->inputs; + model->operators.erase(FindOp(*model, pad_op)); + DeleteArrayIfUnused(pad_op_inputs[1], model); + } + + // SpaceToBatch Op + DeleteArrayIfUnused(stb_op->outputs[0], model); + std::vector stb_op_inputs = stb_op->inputs; + model->operators.erase(FindOp(*model, stb_op)); + DeleteArrayIfUnused(stb_op_inputs[1], model); + DeleteArrayIfUnused(stb_op_inputs[2], model); + + LOG(INFO) << "Replaced with Dilated Conv2D op outputting \"" + << conv_op->outputs[0] << "\"."; + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc index d36e95060937d6af0789766bcb29ae70cef4569d..de6d8889fb4ccdb56e9639ab0dd7d093bfa4b908 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/identify_relu1.cc @@ -57,45 +57,60 @@ int GetSingleScalarInputIndexOfBinaryOp(Model* model, const Operator* op, } // namespace bool IdentifyRelu1::Run(Model* model, std::size_t op_index) { - const auto maximum_it = model->operators.begin() + op_index; - const auto* maximum_op = maximum_it->get(); - if (maximum_op->type != OperatorType::kTensorFlowMaximum) { + // Follow sequences of min+max and max+min. First get the leading op. + const auto op_it = model->operators.begin() + op_index; + const auto* op_0 = op_it->get(); + if (op_0->type != OperatorType::kTensorFlowMinimum && + op_0->type != OperatorType::kTensorFlowMaximum) { return false; } - CHECK_EQ(maximum_op->inputs.size(), 2); - if (maximum_op->outputs.size() != 1) { - return false; - } - int scalar_input_index = - GetSingleScalarInputIndexOfBinaryOp(model, maximum_op, -1.0f); - if (scalar_input_index == -1) { + + // Get the paired op and ensure it's the counter to the first. + const auto* op_1 = GetOpWithInput(*model, op_0->outputs[0]); + if (!op_1 || + (op_1->type != OperatorType::kTensorFlowMinimum && + op_1->type != OperatorType::kTensorFlowMaximum) || + op_0->type == op_1->type) { return false; } - const auto* minimum_op = GetOpWithInput(*model, maximum_op->outputs[0]); - if (!minimum_op || minimum_op->type != OperatorType::kTensorFlowMinimum) { + + const auto* min_op = + op_0->type == OperatorType::kTensorFlowMinimum ? op_0 : op_1; + const auto* max_op = + op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1; + + CHECK_EQ(min_op->inputs.size(), 2); + CHECK_EQ(max_op->inputs.size(), 2); + if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) { return false; } - if (GetSingleScalarInputIndexOfBinaryOp(model, minimum_op, 1.0f) == -1) { + + // Get the original input to the min+max pair. + int min_scalar_input_index = + GetSingleScalarInputIndexOfBinaryOp(model, min_op, 1.0f); + int max_scalar_input_index = + GetSingleScalarInputIndexOfBinaryOp(model, max_op, -1.0f); + if (min_scalar_input_index == -1 || max_scalar_input_index == -1) { return false; } - CHECK_EQ(minimum_op->inputs.size(), 2); + int op_0_scalar_input_index = + op_0 == min_op ? min_scalar_input_index : max_scalar_input_index; - // Create and emplace Relu1 node + // Create and emplace Relu1 node. auto* relu1_op = new Relu1Operator; - relu1_op->inputs = {maximum_op->inputs[!scalar_input_index]}; - relu1_op->outputs = minimum_op->outputs; - model->operators.emplace(maximum_it, relu1_op); + relu1_op->inputs = {op_0->inputs[!op_0_scalar_input_index]}; + relu1_op->outputs = op_1->outputs; + model->operators.emplace(op_it, relu1_op); AddMessageF("Creating %s replacing equivalent subgraph", LogName(*relu1_op)); - // Erase Maximum scalar input & operator - model->EraseArray(maximum_op->inputs[scalar_input_index]); - model->operators.erase(FindOperator(model, maximum_op)); - - // Erase Minimum inputs & operator - model->EraseArray(minimum_op->inputs[0]); - model->EraseArray(minimum_op->inputs[1]); - model->operators.erase(FindOperator(model, minimum_op)); + // Erase op scalar inputs & operators. Note that we preserve the non-scalar + // input to the first op as that's been redirected to the relu1_op. + DeleteArrayIfUsedOnce(op_0->inputs[op_0_scalar_input_index], model); + DeleteArrayIfUsedOnce(op_1->inputs[0], model); + DeleteArrayIfUsedOnce(op_1->inputs[1], model); + model->operators.erase(FindOperator(model, op_0)); + model->operators.erase(FindOperator(model, op_1)); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 3de251ed70fba25b6320b4bbbff540c2d107598b..0e2e5ecf30053103492337685d85a2aacf832caf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -31,17 +31,22 @@ namespace { void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth, int kheight, int stride_width, int stride_height, + int dilation_width_factor, int dilation_height_factor, PaddingType padding_type, Shape* output_shape, FixedPadding* fixed_padding) { const int input_width = input_shape.dims(2); const int input_height = input_shape.dims(1); const int batch = input_shape.dims(0); + int dilated_kwidth = dilation_width_factor * (kwidth - 1) + 1; + int dilated_kheight = dilation_height_factor * (kheight - 1) + 1; + int output_height = 0; int output_width = 0; if (padding_type == PaddingType::kValid) { - output_height = (input_height + stride_height - kheight) / stride_height; - output_width = (input_width + stride_width - kwidth) / stride_width; + output_height = + (input_height + stride_height - dilated_kheight) / stride_height; + output_width = (input_width + stride_width - dilated_kwidth) / stride_width; } else if (padding_type == PaddingType::kSame) { output_height = (input_height + stride_height - 1) / stride_height; output_width = (input_width + stride_width - 1) / stride_width; @@ -49,10 +54,12 @@ void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth, LOG(FATAL) << "Only supporting SAME or VALID padding"; } - fixed_padding->height = std::max( - 0, ((output_height - 1) * stride_height + kheight - input_height) / 2); + fixed_padding->height = std::max(0, ((output_height - 1) * stride_height + + dilated_kheight - input_height) / + 2); fixed_padding->width = std::max( - 0, ((output_width - 1) * stride_width + kwidth - input_width) / 2); + 0, + ((output_width - 1) * stride_width + dilated_kwidth - input_width) / 2); // Actually had to debug a situation where those were negative due to bad // propagation of placeholder -1 sizes in TensorFlowReshape. @@ -166,7 +173,8 @@ void ProcessConvOperator(Model* model, ConvOperator* op) { const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width, - op->stride_height, op->padding.type, + op->stride_height, op->dilation_width_factor, + op->dilation_height_factor, op->padding.type, output_array.mutable_shape(), &op->padding.GetOrCreateFixedPadding()); CHECK_EQ(output_array.shape().dimensions_count(), 4); @@ -222,7 +230,7 @@ void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) { const int kheight = weights_shape.dims(1); const int kwidth = weights_shape.dims(2); ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width, - op->stride_height, op->padding.type, + op->stride_height, 1, 1, op->padding.type, model->GetArray(output_name).mutable_shape(), &op->padding.GetOrCreateFixedPadding()); } @@ -650,15 +658,34 @@ void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) { } const Shape& input_shape = input_array.shape(); - // This code is slightly suspect. The TensorFlow docs say that the axis - // selection defaults to 0, but we are splitting across the final axis. - const int input_dims_count = input_shape.dimensions_count(); - const int input_depth = input_shape.dims(input_dims_count - 1); - CHECK_EQ(input_depth % op->num_split, 0); - const int split_depth = input_depth / op->num_split; + // Yield until axis is constant. + if (!IsConstantParameterArray(*model, op->inputs[0])) { + return; + } + + const auto& axis_array = model->GetArray(op->inputs[0]); + + // Yield until axis dims have been resolved. + if (!axis_array.has_shape()) { + return; + } + + CHECK(axis_array.data_type == ArrayDataType::kInt32) + << "Axis array must be int32."; + CHECK_EQ(RequiredBufferSizeForShape(axis_array.shape()), 1) + << "Axis array must be scalar."; + + int axis = axis_array.GetBuffer().data[0]; + if (axis < 0) { + axis += input_shape.dimensions_count(); + } + + const int split_dim = input_shape.dims(axis); + CHECK_EQ(split_dim % op->num_split, 0); + const int split_depth = split_dim / op->num_split; Shape output_shape = input_shape; - (*output_shape.mutable_dims())[input_dims_count - 1] = split_depth; + (*output_shape.mutable_dims())[axis] = split_depth; CHECK_EQ(op->outputs.size(), op->num_split); for (const auto& output : op->outputs) { @@ -678,7 +705,7 @@ void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) { const string& output_name = op->outputs[0]; const int output_depth = input_shape.dims(3); ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, - op->stride_width, op->stride_height, op->padding.type, + op->stride_width, op->stride_height, 1, 1, op->padding.type, model->GetArray(output_name).mutable_shape(), &op->padding.GetOrCreateFixedPadding()); } @@ -695,7 +722,7 @@ void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) { const string& output_name = op->outputs[0]; const int output_depth = input_shape.dims(3); ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, - op->stride_width, op->stride_height, op->padding.type, + op->stride_width, op->stride_height, 1, 1, op->padding.type, model->GetArray(output_name).mutable_shape(), &op->padding.GetOrCreateFixedPadding()); } @@ -714,7 +741,7 @@ void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) { const string& output_name = op->outputs[0]; const int output_depth = input_shape.dims(3); ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight, - op->stride_width, op->stride_height, op->padding.type, + op->stride_width, op->stride_height, 1, 1, op->padding.type, model->GetArray(output_name).mutable_shape(), &op->padding.GetOrCreateFixedPadding()); } @@ -963,6 +990,43 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { } } +void ProcessTopkV2Operator(Model* model, TopKV2Operator* op) { + const auto& input_values = model->GetArray(op->inputs[0]); + const auto& input_k = model->GetArray(op->inputs[1]); + auto& output_indexes = model->GetArray(op->outputs[0]); + auto& output_values = model->GetArray(op->outputs[1]); + + // Bail if we already know the output shape. + if (output_indexes.has_shape()) { + QCHECK(output_values.has_shape()); + return; + } + + // Yield until input dims have been resolved. + if (!input_values.has_shape()) { + return; + } + + const auto& input_values_shape = input_values.shape(); + auto output_indexes_dims = output_indexes.mutable_shape()->mutable_dims(); + auto output_values_dims = output_values.mutable_shape()->mutable_dims(); + for (int dim = 0; dim < input_values_shape.dimensions_count() - 1; dim++) { + output_indexes_dims->push_back(input_values_shape.dims(dim)); + output_values_dims->push_back(input_values_shape.dims(dim)); + } + // If the value is initialized, we can specify the last dimension, otherwise + // unknown. + if (input_k.buffer) { + const int32_t k_value = input_k.GetBuffer().data[0]; + output_indexes_dims->push_back(k_value); + output_values_dims->push_back(k_value); + + } else { + output_indexes_dims->push_back(0); + output_values_dims->push_back(0); + } +} + void ProcessPadOperator(Model* model, PadOperator* op) { CHECK_EQ(op->inputs.size(), 2); CHECK_EQ(op->outputs.size(), 1); @@ -1308,12 +1372,15 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) { case OperatorType::kTensorFlowAssert: case OperatorType::kCast: case OperatorType::kFloor: + case OperatorType::kExp: ProcessSimpleOperator(model, op); break; case OperatorType::kGather: ProcessGatherOperator(model, static_cast(op)); break; - + case OperatorType::kTopK_V2: + ProcessTopkV2Operator(model, static_cast(op)); + break; case OperatorType::kAdd: case OperatorType::kSub: case OperatorType::kMul: diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc index d7f804ee432598cafe6b6c05d03219aa7d2783fa..77316751bc2642a0c974d16f694aeebe1cd53a9f 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc @@ -100,7 +100,13 @@ void QuantizeArray(GraphTransformation* transformation, Model* model, void QuantizeArray(GraphTransformation* transformation, Model* model, const string& name, ArrayDataType quantized_data_type, const QuantizationParams& quantization_params) { - switch (quantized_data_type) { + ArrayDataType adjusted_data_type = quantized_data_type; + auto& array = model->GetArray(name); + if (array.final_data_type == ArrayDataType::kInt16) { + adjusted_data_type = array.final_data_type; + } + + switch (adjusted_data_type) { case ArrayDataType::kUint8: return QuantizeArray(transformation, model, name, quantization_params); @@ -166,6 +172,60 @@ const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) { "proceed with quantization."; } +struct QuantizationPoints { + int64 min_value; + int64 max_value; + int64 central_value; +}; + +template +QuantizationPoints GetQuantizationPoints() { + QuantizationPoints qp; + using Integer = DataType; + qp.min_value = std::numeric_limits::min(); + qp.max_value = std::numeric_limits::max(); + // eg [-128,127]... + qp.central_value = (qp.min_value / 2 + // -128 -> -64. + (qp.max_value - 1) / 2 + // 127 -> 63. + 1); + return qp; +} + +QuantizationPoints GetQuantizationPoints(ArrayDataType data_type) { + switch (data_type) { + case ArrayDataType::kUint8: + return GetQuantizationPoints(); + case ArrayDataType::kInt16: + return GetQuantizationPoints(); + case ArrayDataType::kInt32: + return GetQuantizationPoints(); + default: + LOG(FATAL) << "Unhandled case."; + } +} + +ArrayDataType GetQuantizedDataType(const Array& array, + ArrayDataType default_type) { + switch (array.final_data_type) { + case ArrayDataType::kInt8: + case ArrayDataType::kUint8: + case ArrayDataType::kInt16: + case ArrayDataType::kUint16: + case ArrayDataType::kInt32: + case ArrayDataType::kUint32: + case ArrayDataType::kInt64: + case ArrayDataType::kUint64: + return array.final_data_type; + case ArrayDataType::kFloat: + case ArrayDataType::kNone: + return default_type; + default: + LOG(FATAL) << "Unhandled final quantization type " + << static_cast(array.final_data_type); + return default_type; + } +} + bool ChooseQuantizationForOperatorInput( GraphTransformation* transformation, Model* model, const Operator& op, std::size_t input_index, ArrayDataType* quantized_data_type, @@ -212,7 +272,7 @@ bool ChooseQuantizationForOperatorInput( const auto input_weights_scale = input_weights.quantization_params->scale; quantization_params->scale = input_activations_scale * input_weights_scale; quantization_params->zero_point = 0; - *quantized_data_type = ArrayDataType::kInt32; + *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kInt32); transformation->AddMessageF( "Input array %s is a bias vector. Choosing quantization params " "accordingly.", @@ -233,14 +293,14 @@ bool ChooseQuantizationForOperatorInput( GetQuantizationParamsFromMinMax(model->flags, minmax, quantization_params); + *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8); transformation->AddMessageF( "For input array %s with min=%g" ", max=%g" - ", chose to quantize as uint8 with zero_point=%d" + ", chose to quantize as %s with zero_point=%d" ", scale=%g", - input, minmax.min, minmax.max, quantization_params->zero_point, - quantization_params->scale); - *quantized_data_type = ArrayDataType::kUint8; + input, minmax.min, minmax.max, ArrayDataTypeName(*quantized_data_type), + quantization_params->zero_point, quantization_params->scale); return true; } @@ -262,16 +322,18 @@ bool IsExactlyRepresentable(double real_value, ArrayDataType data_type, return true; } +// Quantized data type is preset to the type of the input before this function. bool ChooseHardcodedQuantizationForOperatorOutput( - const Operator& op, ArrayDataType* quantized_data_type, + const Operator& op, const Array& array, ArrayDataType* quantized_data_type, QuantizationParams* quantization_params) { if (op.type == OperatorType::kL2Normalization) { // L2Normalization has range: [-1, 1]. // 0 should be exactly representable, as values will typically be centered // around 0, with many values near 0. - *quantized_data_type = ArrayDataType::kUint8; - quantization_params->zero_point = 128; - quantization_params->scale = 1. / 128.; + *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type); + const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type); + quantization_params->zero_point = qp.central_value; + quantization_params->scale = 1. / (qp.central_value - qp.min_value); CHECK( IsExactlyRepresentable(0., *quantized_data_type, *quantization_params)); return true; @@ -284,18 +346,20 @@ bool ChooseHardcodedQuantizationForOperatorOutput( // will typically exploit the symmetry logistic(-x) = 1 - logistic(x), and // the glueing of the two halves of the graph will only be seamless if we // are accurately representing logistic(0) == 0.5. - *quantized_data_type = ArrayDataType::kUint8; + *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type); + const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type); quantization_params->zero_point = 0; - quantization_params->scale = 1. / 256.; + quantization_params->scale = 1. / (qp.max_value + 1); CHECK(IsExactlyRepresentable(0.5, *quantized_data_type, *quantization_params)); return true; } if (op.type == OperatorType::kTanh) { // Tanh has the range: [-1, 1]. - *quantized_data_type = ArrayDataType::kUint8; - quantization_params->zero_point = 128; - quantization_params->scale = 1. / 128.; + *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type); + const QuantizationPoints qp = GetQuantizationPoints(*quantized_data_type); + quantization_params->zero_point = qp.central_value; + quantization_params->scale = 1. / (qp.central_value - qp.min_value); // 0 should be exactly representable, as values will typically be centered // around 0, with many values near 0. CHECK( @@ -314,8 +378,9 @@ bool ChooseQuantizationForOperatorOutput( if (array.data_type != ArrayDataType::kFloat) { return false; } - if (ChooseHardcodedQuantizationForOperatorOutput(op, quantized_data_type, - quantization_params)) { + *quantized_data_type = model->GetArray(op.inputs[0]).data_type; + if (ChooseHardcodedQuantizationForOperatorOutput( + op, array, quantized_data_type, quantization_params)) { transformation->AddMessageF( "Output array %s is produced by a %s operator. Choosing fixed " "quantization params accordingly.", @@ -323,12 +388,21 @@ bool ChooseQuantizationForOperatorOutput( return true; } if ((op.type == OperatorType::kDepthToSpace) || - (op.type == OperatorType::kSpaceToDepth)) { - // DepthToSpace and SpaceToDepth should preserve the quantization parameters - // of the input array, as these are simple reshape operations. - const auto& input_quantization_params = - model->GetArray(op.inputs[0]).GetQuantizationParams(); - *quantized_data_type = ArrayDataType::kUint8; + (op.type == OperatorType::kSpaceToDepth) || + (op.type == OperatorType::kTensorFlowReshape) || + (op.type == OperatorType::kTensorFlowSplit) || + (op.type == OperatorType::kConcatenation)) { + int data_input_index = 0; + if (op.type == OperatorType::kTensorFlowSplit) { + data_input_index = 1; + } + // Copying and rearrangement ops should preserve the quantization parameters + // of the input array. + const auto& input_array = model->GetArray(op.inputs[data_input_index]); + const auto& input_quantization_params = input_array.GetQuantizationParams(); + *quantized_data_type = + GetQuantizedDataType(input_array, ArrayDataType::kUint8); + *quantized_data_type = GetQuantizedDataType(array, *quantized_data_type); quantization_params->zero_point = input_quantization_params.zero_point; quantization_params->scale = input_quantization_params.scale; @@ -350,13 +424,13 @@ bool ChooseQuantizationForOperatorOutput( } GetQuantizationParamsFromMinMax(model->flags, minmax, quantization_params); - *quantized_data_type = ArrayDataType::kUint8; + *quantized_data_type = GetQuantizedDataType(array, ArrayDataType::kUint8); transformation->AddMessageF( "For output array %s with min=%g, max=%g" - ", chose to quantize as uint8 with zero_point=%d" + ", chose to quantize as %s with zero_point=%d" ", scale=%g", - output, minmax.min, minmax.max, quantization_params->zero_point, - quantization_params->scale); + output, minmax.min, minmax.max, ArrayDataTypeName(*quantized_data_type), + quantization_params->zero_point, quantization_params->scale); return true; } diff --git a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc index cabbc4d313be3069053f056eb0de45c37ba2e7a4..30a005c789bb12e880e8e4534088d99ebacba84a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/reorder_activation_functions.cc @@ -62,6 +62,20 @@ bool ReorderActivationFunctions::Run(Model* model, std::size_t op_index) { return false; } + // If the ac_op was originally producing an output_array we can't reorder as + // otherwise the output array would change. It'd be nice to still be able to + // reorder but if code is relying on the fetch names instead of array indices + // this won't work. + for (int i = 0; i < model->flags.output_arrays_size(); ++i) { + if (model->flags.output_arrays(i) == ac_op->outputs[0]) { + AddMessageF( + "Not exchanging activation function with %s to preserve output array " + "name %s", + LogName(*exchange_op), ac_op->outputs[0]); + return false; + } + } + // Rewire by changing inputs, including all consumers. Operator* consumer = GetFirstOpWithInput(*model, ac_op_output); while (consumer) { @@ -75,6 +89,10 @@ bool ReorderActivationFunctions::Run(Model* model, std::size_t op_index) { ac_op->inputs[0] = exchange_op_input; exchange_op->inputs[0] = ac_op_output; + // Clear shapes; this will allow shape propagation to fix the sizes for us. + model->GetOrCreateArray(ac_op->outputs[0]).clear_shape(); + model->GetOrCreateArray(exchange_op->outputs[0]).clear_shape(); + // Finally, reorder operators. Note that this only works when there are no // other direct descendents of the exchange_op. ac_op.swap(exchange_op); diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc index 1cd2aff28c68eaba4e9b18d8e2c2803834328696..f227554bc505efe6a758fdd9894fee43f2500641 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_unary.cc @@ -139,14 +139,13 @@ bool ResolveConstantUnaryOperator::Run(Model* model, std::size_t op_index) { output_buffer_size * sizeof(output_float_data[0])); } else if (unary_op->type == OperatorType::kTensorFlowSum) { // At the moment only full reduction across all dimensions is supported. - for (int i = 0; i < output_dims_count; i++) { - CHECK_EQ(output_shape.dims(i), 1); - } float sum = 0.f; for (int i = 0; i < input_buffer_size; i++) { sum += (*input_float_data)[i]; } - output_float_data[0] = sum; + for (int i = 0; i < output_buffer_size; ++i) { + output_float_data[i] = sum; + } } else if (unary_op->type == OperatorType::kTensorFlowMin) { // At the moment only full reduction across all dimensions is supported. // TODO(starka): Output should not be padded. diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc new file mode 100644 index 0000000000000000000000000000000000000000..37beb41dfc5904fc6ace79ebea2420d2ab92fbfb --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc @@ -0,0 +1,152 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" + +namespace toco { + +namespace { + +template +bool AreAllBufferElementsZero(const std::vector& buffer_data) { + for (auto x : buffer_data) { + if (x != 0) { + return false; + } + } + return true; +} + +template +void FillArrayWithZeros(Array* array) { + CHECK(array->data_type == Type); + std::vector>& data = array->GetMutableBuffer().data; + data.resize(RequiredBufferSizeForShape(array->shape())); + for (size_t i = 0; i < data.size(); i++) { + data[i] = 0; + } +} + +} // namespace + +// Removes a multiplication by array of constant zeros by making the output +// array an array of constant zeros and removing the input arrays if they are no +// longer needed. +bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) { + const auto mul_it = model->operators.begin() + op_index; + auto* mul_op = mul_it->get(); + if (mul_op->type != OperatorType::kMul) { + return false; + } + const auto& output_array_name = mul_op->outputs[0]; + auto& output_array = model->GetArray(output_array_name); + + // Yield if the output shape is not known yet. + if (!output_array.has_shape()) { + return false; + } + + // This transformation only handles the case where one operand is all 0's and + // the other is non-constant. Other cases are handled by constant propagation + // or the trivial binary removal pass. + const bool is_input_constant[2] = { + IsConstantParameterArray(*model, mul_op->inputs[0]), + IsConstantParameterArray(*model, mul_op->inputs[1]), + }; + if (!is_input_constant[0] && !is_input_constant[1]) { + // Neither input is constant, so nothing we can resolve here. + return false; + } + if (is_input_constant[0] && is_input_constant[1]) { + // Both inputs are constants. That's a job for constants propagation, not + // for us to handle here. + return false; + } + const int index_of_constant_input = is_input_constant[0] ? 0 : 1; + const int index_of_variable_input = is_input_constant[0] ? 1 : 0; + CHECK(is_input_constant[index_of_constant_input]); + CHECK(!is_input_constant[index_of_variable_input]); + + const auto& constant_input_array = + model->GetArray(mul_op->inputs[index_of_constant_input]); + + CHECK(constant_input_array.data_type == output_array.data_type); + switch (output_array.data_type) { + case ArrayDataType::kFloat: { + const auto& constant_input_data = + constant_input_array.GetBuffer().data; + if (!AreAllBufferElementsZero>( + constant_input_data)) { + return false; + } + FillArrayWithZeros(&output_array); + } break; + case ArrayDataType::kUint8: { + const auto& constant_input_data = + constant_input_array.GetBuffer().data; + if (!AreAllBufferElementsZero>( + constant_input_data)) { + return false; + } + FillArrayWithZeros(&output_array); + } break; + case ArrayDataType::kInt32: { + const auto& constant_input_data = + constant_input_array.GetBuffer().data; + if (!AreAllBufferElementsZero>( + constant_input_data)) { + return false; + } + FillArrayWithZeros(&output_array); + } break; + case ArrayDataType::kInt64: { + const auto& constant_input_data = + constant_input_array.GetBuffer().data; + if (!AreAllBufferElementsZero>( + constant_input_data)) { + return false; + } + FillArrayWithZeros(&output_array); + } break; + default: + AddMessageF( + "Cannot resolve multiply by 0 because of unsupported data type\n"); + return false; + } + + // Erase input arrays to the multiply if no longer used + if (IsDiscardableArray(*model, mul_op->inputs[0]) && + CountOpsWithInput(*model, mul_op->inputs[0]) == 1) { + model->EraseArray(mul_op->inputs[0]); + } + if (IsDiscardableArray(*model, mul_op->inputs[1]) && + CountOpsWithInput(*model, mul_op->inputs[1]) == 1) { + model->EraseArray(mul_op->inputs[1]); + } + + // Erase the multiply operator. + model->operators.erase(mul_it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc index 41d6c832f0c635dd80318ec277bcf7adb9986f2a..27d2f33a8d278156262753e6572c10ff967bda4c 100644 --- a/tensorflow/contrib/lite/toco/import_tensorflow.cc +++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc @@ -21,6 +21,7 @@ limitations under the License. #include "google/protobuf/map.h" #include "google/protobuf/text_format.h" +#include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" @@ -364,7 +365,7 @@ void ConvertConvOperator(const NodeDef& node, // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. - if (node.attr().count("data_format")) { + if (HasAttr(node, "data_format")) { CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); } CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); @@ -398,6 +399,17 @@ void ConvertConvOperator(const NodeDef& node, CHECK_EQ(strides.i(3), 1); conv->stride_height = strides.i(1); conv->stride_width = strides.i(2); + if (HasAttr(node, "dilations")) { + const auto& dilations = GetListAttr(node, "dilations"); + CHECK_EQ(dilations.i_size(), 4); + CHECK_EQ(dilations.i(0), 1); + CHECK_EQ(dilations.i(3), 1); + conv->dilation_height_factor = dilations.i(1); + conv->dilation_width_factor = dilations.i(2); + } else { + conv->dilation_height_factor = 1; + conv->dilation_width_factor = 1; + } const auto& padding = GetStringAttr(node, "padding"); if (padding == "SAME") { conv->padding.type = PaddingType::kSame; @@ -417,7 +429,7 @@ void ConvertDepthwiseConvOperator(const NodeDef& node, // We only support NHWC, which is the default data_format. // So if data_format is not defined, we're all good. - if (node.attr().count("data_format")) { + if (HasAttr(node, "data_format")) { CHECK_EQ(GetStringAttr(node, "data_format"), "NHWC"); } CHECK_EQ(GetDataTypeAttr(node, "T"), DT_FLOAT); @@ -1460,6 +1472,17 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node, model->operators.emplace_back(op); } +void ConvertExpOperator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK_EQ(node.op(), "Exp"); + CheckInputsCount(node, tf_import_flags, 1); + auto* op = new ExpOperator; + op->inputs.push_back(node.input(0)); + op->outputs.push_back(node.name()); + model->operators.emplace_back(op); +} + void ConvertMeanOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { @@ -1570,7 +1593,7 @@ void ConvertFloorDivOperator(const NodeDef& node, void ConvertFloorModOperator(const NodeDef& node, const TensorFlowImportFlags& tf_import_flags, Model* model) { - CHECK(node.op() == "FloorMod"); + CHECK_EQ(node.op(), "FloorMod"); CheckInputsCount(node, tf_import_flags, 2); auto* op = new FloorModOperator; op->inputs.push_back(node.input(0)); @@ -1805,6 +1828,37 @@ bool InlineAllFunctions(GraphDef* graphdef) { } return graph_modified; } + +void ConvertTopKV2Operator(const NodeDef& node, + const TensorFlowImportFlags& tf_import_flags, + Model* model) { + CHECK((node.op() == "TopK") || (node.op() == "TopKV2")); + auto op = absl::make_unique(); + op->inputs.push_back(node.input(0)); + // K can be encoded as attr (TopK) convert it to a const. + if (HasAttr(node, "k")) { + // Convert attribute into const tensor. + const string array_name = node.name() + "k"; + auto& array = model->GetOrCreateArray(array_name); + array.data_type = ArrayDataType::kInt32; + // Size of array is always 1. + array.mutable_shape()->mutable_dims()->emplace_back(1); + + auto& output_int_data = + array.GetMutableBuffer().data; + output_int_data.resize(1); + output_int_data[0] = GetIntAttr(node, "k"); + op->inputs.push_back(array_name); + + } else { + CheckInputsCount(node, tf_import_flags, 2); + op->inputs.push_back(node.input(1)); + } + // The op has two outputs. + op->outputs.push_back(node.name() + ":0"); + op->outputs.push_back(node.name() + ":1"); + model->operators.emplace_back(op.release()); +} } // namespace std::unique_ptr ImportTensorFlowGraphDef( @@ -1986,6 +2040,10 @@ std::unique_ptr ImportTensorFlowGraphDef( ConvertTransposeOperator(node, tf_import_flags, model); } else if (node.op() == "ArgMax") { ConvertArgMaxOperator(node, tf_import_flags, model); + } else if (node.op() == "Exp") { + ConvertExpOperator(node, tf_import_flags, model); + } else if (node.op() == "TopK" || node.op() == "TopKV2") { + ConvertTopKV2Operator(node, tf_import_flags, model); } else { ConvertUnsupportedOperator(node, tf_import_flags, model); } diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index 0bee694387c3da58abaf5a065c9960ceb3f84aba..346859ab392d257355b21411a1b3691c8dda5421 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -44,6 +44,7 @@ enum class OperatorType { kSpaceToDepth, kDequantize, kDiv, + kExp, kExpandDims, kFill, kFloorDiv, @@ -113,6 +114,7 @@ enum class OperatorType { kTensorFlowSwitch, kTensorFlowTile, kTranspose, + kTopK_V2, // An unsupported TF operation. It's only needed to be able to represent TF // graph internally and is expected to be dropped by graph transformations. kTensorFlowUnsupported, @@ -158,17 +160,17 @@ enum class AxesOrder { // may be involved only in debug-only subgraphs that we may not be interested // in actually supporting). enum class ArrayDataType { - kNone, + kNone, // 0 kBool, kFloat, kInt8, kUint8, - kInt16, + kInt16, // 5 kUint16, kInt32, kUint32, kInt64, - kUint64, + kUint64, // 10 kString }; @@ -357,7 +359,8 @@ struct ConvOperator : Operator { // A dilation_rate of 0 is invalid and this field is an optional attribute. // Thus initializing it to 1 to allow default conv behavior when the // attribute is not present. - int dilation_rate = 1; + int dilation_width_factor = 1; + int dilation_height_factor = 1; }; // Depthwise-separable convolution operator. @@ -852,6 +855,17 @@ struct TransposeConvOperator : Operator { int stride_height = 0; }; +// Given a tensor input, this operation calculates element-wise exponential +// (y = e^x). +// +// Inputs: +// inputs[0]: required: input tensor +// +// TensorFlow equivalent: Exp +struct ExpOperator : Operator { + ExpOperator() : Operator(OperatorType::kExp) {} +}; + // Given a tensor input, this operation inserts a dimension of 1 at the // dimension index axis of input's shape. The dimension index axis starts at // zero; if you specify a negative number for axis it is counted backward from @@ -1388,6 +1402,14 @@ struct SvdfOperator : Operator { int rank; }; +// TopKV2 operator. +// +// Inputs: +// input tensor and top_k scalar. +struct TopKV2Operator : Operator { + TopKV2Operator() : Operator(OperatorType::kTopK_V2) {} +}; + // Alloc's are used for transient arrays only. An Alloc specifies which interval // of the "transient_data" workspace buffer passed to inference functions, is to // be used for the transient array at hand. The 'start' and 'end' values are diff --git a/tensorflow/contrib/lite/toco/model_flags.proto b/tensorflow/contrib/lite/toco/model_flags.proto index e4b39b34e85e4d703c1b41cb68f8139abd1f6279..867b86f31d16b502a7aeb92cb3d8c96117630cd2 100644 --- a/tensorflow/contrib/lite/toco/model_flags.proto +++ b/tensorflow/contrib/lite/toco/model_flags.proto @@ -96,9 +96,11 @@ message RnnState { // model that does not already contain such MinMax information. message ArraysExtraInfo { message Entry { + // Next ID to use: 5. optional string name = 1; optional float min = 2; optional float max = 3; + optional IODataType data_type = 4; } repeated Entry entries = 1; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index ff54b350bf072dc8f725dea9bc19544caab3fa70..f2cc4ef71f71902e363ac4cddd3695446af30c7d 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -601,15 +601,21 @@ class Squeeze } }; -class Split : public CustomOperator { +class Split + : public BuiltinOperator { public: - using CustomOperator::CustomOperator; - void WriteOptions(const TocoOperator& op, - flexbuffers::Builder* fbb) const override { - fbb->Int("num_split", op.num_split); + using BuiltinOperator::BuiltinOperator; + + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateSplitOptions(*builder, op.num_split); } - void ReadOptions(const flexbuffers::Map& m, TocoOperator* op) const override { - op->num_split = m["num_split"].AsInt64(); + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->num_split = options.num_splits(); } }; @@ -637,6 +643,20 @@ class StridedSlice } }; +class TopK_V2 : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateTopKV2Options(*builder); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override {} +}; + class TensorFlowUnsupported : public BaseOperator { public: using BaseOperator::BaseOperator; @@ -799,8 +819,12 @@ std::vector> BuildOperatorList() { OperatorType::kResizeBilinear)); ops.emplace_back( new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze)); + ops.emplace_back(new Split(::tflite::BuiltinOperator_SPLIT, + OperatorType::kTensorFlowSplit)); ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE, OperatorType::kStridedSlice)); + ops.emplace_back( + new TopK_V2(::tflite::BuiltinOperator_TOPK_V2, OperatorType::kTopK_V2)); ops.emplace_back( new Lstm(::tflite::BuiltinOperator_LSTM, OperatorType::kLstmCell)); @@ -809,7 +833,6 @@ std::vector> BuildOperatorList() { ops.emplace_back( new DepthToSpace("DEPTH_TO_SPACE", OperatorType::kDepthToSpace)); ops.emplace_back(new FakeQuant("FAKE_QUANT", OperatorType::kFakeQuant)); - ops.emplace_back(new Split("SPLIT", OperatorType::kTensorFlowSplit)); ops.emplace_back(new TensorFlowUnsupported( "TENSORFLOW_UNSUPPORTED", OperatorType::kTensorFlowUnsupported)); @@ -835,6 +858,9 @@ std::vector> BuildOperatorList() { "LOGISTIC", OperatorType::kLogistic)); ops.emplace_back( new SimpleOperator("TANH", OperatorType::kTanh)); + ops.emplace_back(new SimpleOperator("EXP", OperatorType::kExp)); + ops.emplace_back(new SimpleOperator( + "LOG_SOFTMAX", OperatorType::kLogSoftmax)); return ops; } diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index 796534be53cba0ea772e974cd8173c0b4c12e6c3..9c19f8d4649acf40fdd85b78874f7b18798533f2 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -106,6 +106,9 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("RELU6", OperatorType::kRelu6); CheckSimpleOperator("LOGISTIC", OperatorType::kLogistic); CheckSimpleOperator("TANH", OperatorType::kTanh); + CheckSimpleOperator("EXP", OperatorType::kExp); + CheckSimpleOperator("LOG_SOFTMAX", + OperatorType::kLogSoftmax); } TEST_F(OperatorTest, BuiltinAdd) { @@ -379,6 +382,13 @@ TEST_F(OperatorTest, StridedSlice) { EXPECT_EQ(op.shrink_axis_mask, output_toco_op->shrink_axis_mask); } +TEST_F(OperatorTest, BuiltinTopKV2) { + TopKV2Operator op; + auto output_toco_op = SerializeAndDeserialize( + GetOperator("TOPK_V2", OperatorType::kTopK_V2), op); + ASSERT_NE(nullptr, output_toco_op.get()); +} + TEST_F(OperatorTest, TensorFlowUnsupported) { TensorFlowUnsupportedOperator op; op.tensorflow_op = "MyCustomUnsupportedOp"; diff --git a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc index c5a62fdb620ee7d6b7195f6e8e2bc3cb208feb10..0f67c2de728532b5b8101b3514811a78a3b3bc38 100644 --- a/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc +++ b/tensorflow/contrib/lite/toco/toco_cmdline_flags.cc @@ -112,6 +112,11 @@ bool ParseTocoFlagsFromCommandLineFlags( "If true, ignore control dependency requirements in input TensorFlow " "GraphDef. Otherwise an error will be raised upon control dependency " "inputs."), + Flag("debug_disable_recurrent_cell_fusion", + parsed_flags.debug_disable_recurrent_cell_fusion.bind(), + parsed_flags.debug_disable_recurrent_cell_fusion.default_value(), + "If true, disable fusion of known identifiable cell subgraphs into " + "cells. This includes, for example, specific forms of LSTM cell."), }; bool asked_for_help = *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help")); diff --git a/tensorflow/contrib/lite/toco/toco_flags.proto b/tensorflow/contrib/lite/toco/toco_flags.proto index 3b9d7e22570b66aef2c9fc819e5ab4ec38e179f5..3237147a736f97f65953ca965420fcea934820a4 100644 --- a/tensorflow/contrib/lite/toco/toco_flags.proto +++ b/tensorflow/contrib/lite/toco/toco_flags.proto @@ -36,7 +36,8 @@ enum FileFormat { // are not normally encoded in model files and in general may not be thought // of as properties of models, instead describing how models are to be // processed in the context of the present tooling job. -// Next Id: 13 +// +// Next ID to use: 14. message TocoFlags { // Input file format optional FileFormat input_format = 1; @@ -136,4 +137,8 @@ message TocoFlags { // - Default to false if the output format is TENSORFLOW_GRAPHDEF. // - Default to true in all other cases. optional bool drop_control_dependency = 12; + + // Disables transformations that fuse subgraphs such as known LSTMs (not all + // LSTMs are identified). + optional bool debug_disable_recurrent_cell_fusion = 13; } diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 5472c52c96ab93a6d3acf0522651d0f8876e08ce..a09a3c4ef56edc6ba7fd19eb1ff45a2e41cf3dd2 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -86,6 +86,8 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveTensorFlowSwitch); transformations->Add(new ResolveTensorFlowTile); transformations->Add(new ResolveTensorFlowConcat); + transformations->Add(new ResolveMultiplyByZero); + transformations->Add(new IdentifyDilatedConv); transformations->Add(new IdentifyL2Normalization); transformations->Add(new IdentifyL2Pool); transformations->Add(new IdentifyRelu1); @@ -188,20 +190,23 @@ std::unique_ptr Import(const TocoFlags& toco_flags, } void Transform(const TocoFlags& toco_flags, Model* model) { + // Clean up after import. + SetFinalDataTypeOnInputs(toco_flags, model); + UseArraysExtraInfo(model); + FinishBuildingRNNStates(model); + const FileFormat output_format = toco_flags.output_format(); const IODataType inference_type = toco_flags.inference_type(); const bool quantize_output = - SupportsQuantization(output_format) && inference_type == QUANTIZED_UINT8; + SupportsQuantization(output_format) && + (inference_type == QUANTIZED_UINT8 || inference_type == QUANTIZED_INT16); if (quantize_output) { QCHECK_NE(toco_flags.inference_input_type(), FLOAT) << "Quantized inference is not allowed with float inputs."; } - SetFinalDataTypeOnInputs(toco_flags, model); - UseArraysExtraInfo(model); - // Remove unused ops before performing any other optimizations. This is to // stop optimizations from crossing the input/output boundaries. For example // this will stop BatchNorm fusing if the output node is in between a conv @@ -231,7 +236,9 @@ void Transform(const TocoFlags& toco_flags, Model* model) { } transformations.Add(new ConvertPureConvToDepthwise); if (SupportsLstmCell(output_format)) { - transformations.Add(new IdentifyLstmCell); + if (!toco_flags.debug_disable_recurrent_cell_fusion()) { + transformations.Add(new IdentifyLstmCell); + } if (output_format == TFLITE) { transformations.Add(new toco::SplitLstmCellInputs); } else { diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc index ce0fde57f4ca053ec00926c3cb350bbb8d8bd3dc..9e725822383b06985bbb5cffdc19a759bc6d5cf3 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.cc +++ b/tensorflow/contrib/lite/toco/tooling_util.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "absl/strings/str_replace.h" +#include "absl/strings/str_split.h" #include "tensorflow/contrib/lite/toco/dump_graphviz.h" #include "tensorflow/contrib/lite/toco/model_flags.pb.h" #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h" @@ -61,6 +62,35 @@ string LogName(const Operator& op) { } } +string ArrayDataTypeName(ArrayDataType data_type) { + switch (data_type) { + case ArrayDataType::kFloat: + return "Float"; + case ArrayDataType::kInt8: + return "Int8"; + case ArrayDataType::kUint8: + return "Uint8"; + case ArrayDataType::kInt16: + return "Int16"; + case ArrayDataType::kUint16: + return "Uint16"; + case ArrayDataType::kInt32: + return "Int32"; + case ArrayDataType::kUint32: + return "Uint32"; + case ArrayDataType::kInt64: + return "Int64"; + case ArrayDataType::kUint64: + return "Uint64"; + case ArrayDataType::kString: + return "String"; + case ArrayDataType::kNone: + return "None"; + default: + LOG(FATAL) << "Unhandled array data type " << static_cast(data_type); + } +} + bool IsInputArray(const Model& model, const string& name) { for (const auto& input_array : model.flags.input_arrays()) { if (input_array.name() == name) { @@ -110,7 +140,17 @@ int CountOpsWithInput(const Model& model, const string& array_name) { } bool DeleteArrayIfUnused(const string& array_name, Model* model) { - if (CountOpsWithInput(*model, array_name) == 0) { + if (IsDiscardableArray(*model, array_name) && + CountOpsWithInput(*model, array_name) == 0) { + model->EraseArray(array_name); + return true; + } + return false; +} + +bool DeleteArrayIfUsedOnce(const string& array_name, Model* model) { + if (IsDiscardableArray(*model, array_name) && + CountOpsWithInput(*model, array_name) == 1) { model->EraseArray(array_name); return true; } @@ -302,7 +342,9 @@ const char* OperatorTypeName(OperatorType type) { HANDLE_OPERATORTYPENAME_CASE(Mean) HANDLE_OPERATORTYPENAME_CASE(Svdf) HANDLE_OPERATORTYPENAME_CASE(ArgMax) + HANDLE_OPERATORTYPENAME_CASE(TopK_V2) HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported) + HANDLE_OPERATORTYPENAME_CASE(Exp) default: LOG(FATAL) << "Unhandled op type"; #undef HANDLE_OPERATORTYPENAME_CASE @@ -321,6 +363,7 @@ string HelpfulOperatorTypeName(const Operator& op) { bool OperatorSupportsFusedActivation(OperatorType type) { switch (type) { case OperatorType::kConcatenation: + case OperatorType::kGather: case OperatorType::kSlice: case OperatorType::kSqueeze: case OperatorType::kTensorFlowReshape: @@ -349,48 +392,9 @@ void LogSummary(int log_level, const Model& model) { void LogArray(int log_level, const Model& model, const string& name) { const auto& array = model.GetArray(name); VLOG(log_level) << "Array: " << name; - switch (array.data_type) { - case ArrayDataType::kNone: - VLOG(log_level) << " Data type:"; - break; - case ArrayDataType::kFloat: - VLOG(log_level) << " Data type: kFloat"; - break; - case ArrayDataType::kInt32: - VLOG(log_level) << " Data type: kInt32"; - break; - case ArrayDataType::kUint8: - VLOG(log_level) << " Data type: kUint8"; - break; - case ArrayDataType::kString: - VLOG(log_level) << " Data type: kString"; - break; - default: - VLOG(log_level) << " Data type: other (numerical value: " - << static_cast(array.data_type) << ")"; - break; - } - switch (array.final_data_type) { - case ArrayDataType::kNone: - VLOG(log_level) << " Final type:"; - break; - case ArrayDataType::kFloat: - VLOG(log_level) << " Final type: kFloat"; - break; - case ArrayDataType::kInt32: - VLOG(log_level) << " Final type: kInt32"; - break; - case ArrayDataType::kUint8: - VLOG(log_level) << " Final type: kUint8"; - break; - case ArrayDataType::kString: - VLOG(log_level) << " Final type: kString"; - break; - default: - VLOG(log_level) << " Final type: other (numerical value: " - << static_cast(array.data_type) << ")"; - break; - } + VLOG(log_level) << " Data type: " << ArrayDataTypeName(array.data_type); + VLOG(log_level) << " Final type: " + << ArrayDataTypeName(array.final_data_type); if (array.buffer) { VLOG(log_level) << " Constant Buffer"; } @@ -620,6 +624,14 @@ bool IsConstantParameterArray(const Model& model, const string& name) { } namespace { +// Take an array name, which may be something like "name:3_5" and make it +// acceptable as a TF node name, say "name_3_5"; +string SanitizeNameForTFNode(const string& array_name) { + auto node_name = array_name; + std::replace(node_name.begin(), node_name.end(), ':', '_'); + return node_name; +} + void CheckInputArraysAreNotOutputArrays(const ModelFlags& model_flags) { for (const auto& input_array : model_flags.input_arrays()) { for (const string& output_array : model_flags.output_arrays()) { @@ -783,7 +795,10 @@ void FixNoOrphanedArray(Model* model) { } } -void CheckArrayFieldsConsistent(const Model& model) { +// Apply checks to arrays individually (for-each fashion). +// +// Check consistency of array fields, check name. +void CheckEachArray(const Model& model) { for (const auto& array_entry : model.GetArrayMap()) { const auto& array = array_entry.second; if (array->has_shape()) { @@ -798,6 +813,18 @@ void CheckArrayFieldsConsistent(const Model& model) { if (array->buffer) { CHECK(array->buffer->type == array->data_type); } + + // Check name. Either "name_with_suffix_8", "name_with_port:3", but not + // "name_with_both:3_8". + const string& name = array_entry.first; + auto colon_pos = name.find_first_of(":"); + if (colon_pos != string::npos) { + CHECK_EQ(name.substr(colon_pos + 1).find_first_not_of("0123456789"), + string::npos) + << "Array name must only have digits after colon"; + } + CHECK_GT(colon_pos, 0) + << "First character of array name must not be a colon."; } } @@ -946,7 +973,7 @@ void CheckInvariants(const Model& model) { CheckNonAsciiIOArrays(model.flags); CheckNoMissingArray(model); CheckNoOrphanedArray(model); - CheckArrayFieldsConsistent(model); + CheckEachArray(model); CheckOperatorOrdering(model); } @@ -1038,9 +1065,6 @@ void CreateOrCheckRnnStateArray(const string& name, int size, Model* model) { if (array.has_shape()) { num_dims = array.shape().dimensions_count(); } - CHECK(array.data_type == ArrayDataType::kFloat || - array.data_type == ArrayDataType::kNone); - array.data_type = ArrayDataType::kFloat; if (!array.has_shape() && num_dims >= 0) { Shape* shape = array.mutable_shape(); std::vector dims; @@ -1064,7 +1088,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) { } } if (!dst_input_array) { - // specified_input_array from model_flags is not found in model->flags. + // Specified_input_array from model_flags is not found in model->flags. // Match a name-less specified input array when there can be no ambiguity // as there is only 1 input array. if (model->flags.input_arrays_size() == 1 && @@ -1371,19 +1395,23 @@ bool IsAllocatableTransientArray(const Model& model, const string& array_name) { } string AvailableArrayName(const Model& model, const string& name) { - if (!model.HasArray(name) && !model.IsOptionalArray(name)) { - return name; + string sanitized_name = SanitizeNameForTFNode(name); + if (!model.HasArray(sanitized_name) && + !model.IsOptionalArray(sanitized_name)) { + return sanitized_name; } const int kNumSuffixesToTry = 1000; for (int i = 0; i < kNumSuffixesToTry; i++) { - const string& name_with_suffix = toco::port::StringF("%s_%d", name, i); + const string& name_with_suffix = + toco::port::StringF("%s_%d", sanitized_name, i); if (!model.HasArray(name_with_suffix) && !model.IsOptionalArray(name_with_suffix)) { return name_with_suffix; } } - LOG(FATAL) << "Could not find an available array name starting with " << name - << ". Tried " << kNumSuffixesToTry << " suffixes, all were taken!"; + LOG(FATAL) << "Could not find an available array name starting with " + << sanitized_name << ". Tried " << kNumSuffixesToTry + << " suffixes, all were taken!"; return ""; } @@ -1773,6 +1801,8 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) { return ArrayDataType::kFloat; case QUANTIZED_UINT8: return ArrayDataType::kUint8; + case QUANTIZED_INT16: + return ArrayDataType::kInt16; case INT32: return ArrayDataType::kInt32; case INT64: @@ -1782,14 +1812,39 @@ ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type) { } } +void FinishBuildingRNNStates(Model* model) { + for (const auto& rnn_state : model->flags.rnn_states()) { + if (!model->HasArray(rnn_state.back_edge_source_array()) || + !model->HasArray(rnn_state.state_array())) { + CHECK(model->HasArray(rnn_state.back_edge_source_array())); + CHECK(model->HasArray(rnn_state.state_array())); + continue; + } + const auto& src_array = model->GetArray(rnn_state.back_edge_source_array()); + auto& dst_array = model->GetArray(rnn_state.state_array()); + if (src_array.data_type == ArrayDataType::kNone && + dst_array.data_type == ArrayDataType::kNone) { + dst_array.data_type = ArrayDataType::kFloat; + } + } +} + void UseArraysExtraInfo(Model* model) { for (const auto& entry : model->flags.arrays_extra_info().entries()) { QCHECK(model->HasArray(entry.name())) << "ArraysExtraInfo refers to non-existent array name: " << entry.name(); - auto& minmax = model->GetArray(entry.name()).GetOrCreateMinMax(); - minmax.min = entry.min(); - minmax.max = entry.max(); + auto& array = model->GetArray(entry.name()); + auto& minmax = array.GetOrCreateMinMax(); + if (entry.has_min() || entry.has_max()) { + CHECK_EQ(entry.has_min(), entry.has_max()); + minmax.min = entry.min(); + minmax.max = entry.max(); + } + if (entry.has_data_type()) { + array.final_data_type = + ConvertIODataTypeToArrayDataType(entry.data_type()); + } } } diff --git a/tensorflow/contrib/lite/toco/tooling_util.h b/tensorflow/contrib/lite/toco/tooling_util.h index 3addccaa10c96f5dba165d7aeea0e830a8b1ebfd..11208ed667212d56f9ef45e4f394e0bbf5000cbc 100644 --- a/tensorflow/contrib/lite/toco/tooling_util.h +++ b/tensorflow/contrib/lite/toco/tooling_util.h @@ -54,12 +54,15 @@ absl::string_view FindLongestCommonPrefix(absl::string_view a, absl::string_view b); string LogName(const Operator& op); +string ArrayDataTypeName(ArrayDataType data_type); + bool IsInputArray(const Model& model, const string& name); bool IsArrayConsumed(const Model& model, const string& name); int CountTrueOutputs(const Model& model, const Operator& op); int CountOpsWithInput(const Model& model, const string& array_name); bool DeleteArrayIfUnused(const string& array_name, Model* model); +bool DeleteArrayIfUsedOnce(const string& array_name, Model* model); std::vector>::const_iterator FindOpWithOutput( const Model& model, const string& array_name); @@ -298,6 +301,23 @@ void CheckFinalDataTypesSatisfied(const Model& model); ArrayDataType ConvertIODataTypeToArrayDataType(IODataType type); +// The process of building models varies according to the import format. +// +// (a) In some cases, such as model-proto format, the model should be fully +// specified. In these cases, no extra action should be taken by this function. +// (b) In other cases, such as TF graphdef format, the desired types of RNN +// arrays are not specified directly in the model, neither can they be inferred. +// However, we can set the types of RNN destination arrays to float. This breaks +// any cycles such as when resolution of the type of an RNN source array depends +// on the type of its destination array. +// +// This function is applied after the main import, after resolution of flags and +// after application of ArraysExtraInfo. It only defaults destination RNN arrays +// to float. If the model is subsequently quantized, it is assumed that the +// model contains sufficient information for that to be completed. If it is +// already quantized, then case (a) should hold. +void FinishBuildingRNNStates(Model* model); + void UseArraysExtraInfo(Model* model); } // namespace toco diff --git a/tensorflow/contrib/lite/toco/types.proto b/tensorflow/contrib/lite/toco/types.proto index 318fd4b7b2c2df093562e73c3fe707675ee98876..03bd6150bc86bb27221814cd191b17f1a09585fa 100644 --- a/tensorflow/contrib/lite/toco/types.proto +++ b/tensorflow/contrib/lite/toco/types.proto @@ -34,4 +34,7 @@ enum IODataType { // String, not quantized STRING = 5; + + // Int16, quantized + QUANTIZED_INT16 = 6; } diff --git a/tensorflow/contrib/lite/tools/BUILD b/tensorflow/contrib/lite/tools/BUILD index 6786b1618456637aecfd870b9984af65b59784f6..999ccf2ebc009b6b7c50a9a2d1667d69a3f690e7 100644 --- a/tensorflow/contrib/lite/tools/BUILD +++ b/tensorflow/contrib/lite/tools/BUILD @@ -112,6 +112,7 @@ cc_test( size = "small", srcs = ["verifier_test.cc"], deps = [ + ":mutable_op_resolver", ":verifier", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:schema_fbs_version", diff --git a/tensorflow/contrib/lite/tools/verifier.cc b/tensorflow/contrib/lite/tools/verifier.cc index 726e2aaa3162591593cd2abd6384eb55baf0aef4..59c74205f0a311ec12ff87f46622041605fb493b 100644 --- a/tensorflow/contrib/lite/tools/verifier.cc +++ b/tensorflow/contrib/lite/tools/verifier.cc @@ -155,11 +155,11 @@ bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) { } for (const auto& subgraph : *model.subgraphs()) { if (!subgraph->tensors()) { - return true; + continue; } for (const auto& tensor : *subgraph->tensors()) { if (!tensor->buffer()) { - return true; + continue; } if (tensor->buffer() >= model.buffers()->size()) { ReportError(error_reporter, "Invalid tensor buffer index: %d", @@ -187,9 +187,33 @@ bool VerifyTensors(const Model& model, ErrorReporter* error_reporter) { return true; } +bool VerifyOps(const Model& model, const OpResolver& resolver, + ErrorReporter* error_reporter) { + if (!model.operator_codes()) { + return true; + } + for (const auto& opcode : *model.operator_codes()) { + if (opcode->builtin_code() == BuiltinOperator_CUSTOM) { + if (!resolver.FindOp(opcode->custom_code()->c_str())) { + ReportError(error_reporter, "Unsupported custom op: %s", + opcode->custom_code()->c_str()); + return false; + } + } else { + if (!resolver.FindOp(opcode->builtin_code())) { + ReportError(error_reporter, "Unsupported builtin op: %s", + EnumNameBuiltinOperator(opcode->builtin_code())); + return false; + } + } + } + return true; +} + } // namespace -bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter) { +bool Verify(const void* buf, size_t len, const OpResolver& resolver, + ErrorReporter* error_reporter) { const Model* model = VerifyFlatbufferAndGetModel(buf, len); if (model == nullptr) { ReportError(error_reporter, "Invalid flatbuffer format"); @@ -202,6 +226,9 @@ bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter) { if (!VerifyTensors(*model, error_reporter)) { return false; } + if (!VerifyOps(*model, resolver, error_reporter)) { + return false; + } return true; } } // namespace tflite diff --git a/tensorflow/contrib/lite/tools/verifier.h b/tensorflow/contrib/lite/tools/verifier.h index d2bf3c91d54225098c1f254c26971e8bb962f791..c2ee11215c861ed7b27696a8d786bb6e2a48e930 100644 --- a/tensorflow/contrib/lite/tools/verifier.h +++ b/tensorflow/contrib/lite/tools/verifier.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/model.h" namespace tflite { @@ -26,7 +27,9 @@ namespace tflite { // Currently, it verifies: // * The file is following a legit flatbuffer schema. // * The model is in supported version. -bool Verify(const void* buf, size_t len, ErrorReporter* error_reporter); +// * All ops used in the model are supported by OpResolver. +bool Verify(const void* buf, size_t len, const OpResolver& resolver, + ErrorReporter* error_reporter); } // namespace tflite diff --git a/tensorflow/contrib/lite/tools/verifier_test.cc b/tensorflow/contrib/lite/tools/verifier_test.cc index 87f6854e9e67c0389949c8d72a476036051d1c0f..b3e611f999b2837efbf8876bd989db44c408b8c7 100644 --- a/tensorflow/contrib/lite/tools/verifier_test.cc +++ b/tensorflow/contrib/lite/tools/verifier_test.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/contrib/lite/error_reporter.h" #include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h" #include "tensorflow/contrib/lite/tools/verifier.h" #include "tensorflow/contrib/lite/version.h" #include "tensorflow/core/framework/numeric_types.h" @@ -40,6 +41,19 @@ class TfLiteFlatbufferModelBuilder { CreateBuffer(builder_, builder_.CreateVector(std::vector{}))); } + TfLiteFlatbufferModelBuilder(const std::vector& builtin_ops, + const std::vector& custom_ops) { + buffers_.push_back( + CreateBuffer(builder_, builder_.CreateVector(std::vector{}))); + + for (const auto& iter : builtin_ops) { + resolver_.AddBuiltin(iter, &fake_op_); + } + for (const auto& iter : custom_ops) { + resolver_.AddCustom(iter.data(), &fake_op_); + } + } + void AddTensor(const std::vector& shape, tflite::TensorType type, const std::vector& buffer, const char* name) { int buffer_index = 0; @@ -79,11 +93,13 @@ class TfLiteFlatbufferModelBuilder { bool Verify() { return tflite::Verify(builder_.GetBufferPointer(), builder_.GetSize(), - DefaultErrorReporter()); + resolver_, DefaultErrorReporter()); } private: FlatBufferBuilder builder_; + MutableOpResolver resolver_; + TfLiteRegistration fake_op_; std::vector> operators_; std::vector> operator_codes_; std::vector> tensors_; @@ -98,11 +114,11 @@ TEST(VerifyModel, TestEmptyModel) { ::tflite::FinishModelBuffer(builder, model); ASSERT_TRUE(Verify(builder.GetBufferPointer(), builder.GetSize(), - DefaultErrorReporter())); + MutableOpResolver{}, DefaultErrorReporter())); } TEST(VerifyModel, TestSimpleModel) { - TfLiteFlatbufferModelBuilder builder; + TfLiteFlatbufferModelBuilder builder({}, {"test"}); builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "test"); builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4, 5, 6}, "input"); builder.AddTensor( @@ -116,7 +132,8 @@ TEST(VerifyModel, TestSimpleModel) { TEST(VerifyModel, TestCorruptedData) { std::string model = "123"; - ASSERT_FALSE(Verify(model.data(), model.size(), /*error_reporter=*/nullptr)); + ASSERT_FALSE(Verify(model.data(), model.size(), MutableOpResolver{}, + /*error_reporter=*/nullptr)); } TEST(VerifyModel, TestUnsupportedVersion) { @@ -125,7 +142,7 @@ TEST(VerifyModel, TestUnsupportedVersion) { /*subgraphs=*/0, /*description=*/0, /*buffers=*/0); ::tflite::FinishModelBuffer(builder, model); ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), - DefaultErrorReporter())); + MutableOpResolver{}, DefaultErrorReporter())); } TEST(VerifyModel, TestRandomModificationIsNotAllowed) { @@ -140,7 +157,7 @@ TEST(VerifyModel, TestRandomModificationIsNotAllowed) { for (int i = 0; i < model_content.size(); i++) { model_content[i] = (model_content[i] + 137) % 255; EXPECT_FALSE(Verify(model_content.data(), model_content.size(), - DefaultErrorReporter())) + MutableOpResolver{}, DefaultErrorReporter())) << "Fail at position: " << i; } } @@ -188,7 +205,7 @@ TEST(VerifyModel, TensorBufferIsNotValid) { ::tflite::FinishModelBuffer(builder, model); ASSERT_FALSE(Verify(builder.GetBufferPointer(), builder.GetSize(), - DefaultErrorReporter())); + MutableOpResolver{}, DefaultErrorReporter())); } TEST(VerifyModel, StringTensorHasInvalidNumString) { @@ -229,6 +246,37 @@ TEST(VerifyModel, StringTensorIsLargerThanRequired) { ASSERT_FALSE(builder.Verify()); } +TEST(VerifyModel, AllOpsAreSupported) { + TfLiteFlatbufferModelBuilder builder({BuiltinOperator_ADD}, {"CustomOp"}); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1"); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2"); + builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output"); + builder.AddOperator({0, 1}, {2}, BuiltinOperator_ADD, nullptr); + builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "CustomOp"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, UseUnsupportedBuiltinOps) { + TfLiteFlatbufferModelBuilder builder({BuiltinOperator_SUB}, {"CustomOp"}); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1"); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2"); + builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output"); + builder.AddOperator({0, 1}, {2}, BuiltinOperator_ADD, nullptr); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + +TEST(VerifyModel, UseUnsupportedCustomOps) { + TfLiteFlatbufferModelBuilder builder({BuiltinOperator_ADD}, {"NewOp"}); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input1"); + builder.AddTensor({2, 3}, TensorType_UINT8, {1, 2, 3, 4}, "input2"); + builder.AddTensor({2, 3}, TensorType_UINT8, {}, "output"); + builder.AddOperator({0, 1}, {2}, BuiltinOperator_CUSTOM, "Not supported"); + builder.FinishModel({}, {}); + ASSERT_FALSE(builder.Verify()); +} + // TODO(yichengfan): make up malicious files to test with. } // namespace tflite diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py index 6842bc38eb108b46cc3eff715c9cbc74f991308b..2b9eee4ef7b418e2b90d388d2f165537b8660a9a 100644 --- a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py @@ -50,16 +50,12 @@ def pairwise_distance(feature, squared=False): pairwise_distances: 2-D Tensor of size [number of data, number of data]. """ pairwise_distances_squared = math_ops.add( + math_ops.reduce_sum(math_ops.square(feature), axis=[1], keepdims=True), math_ops.reduce_sum( - math_ops.square(feature), - axis=[1], - keepdims=True), - math_ops.reduce_sum( - math_ops.square( - array_ops.transpose(feature)), + math_ops.square(array_ops.transpose(feature)), axis=[0], - keepdims=True)) - 2.0 * math_ops.matmul( - feature, array_ops.transpose(feature)) + keepdims=True)) - 2.0 * math_ops.matmul(feature, + array_ops.transpose(feature)) # Deal with numerical inaccuracies. Set small negatives to zero. pairwise_distances_squared = math_ops.maximum(pairwise_distances_squared, 0.0) @@ -134,8 +130,8 @@ def masked_maximum(data, mask, dim=1): """ axis_minimums = math_ops.reduce_min(data, dim, keepdims=True) masked_maximums = math_ops.reduce_max( - math_ops.multiply( - data - axis_minimums, mask), dim, keepdims=True) + axis_minimums + math_ops.multiply(data - axis_minimums, mask), dim, + keepdims=True) + axis_minimums return masked_maximums @@ -153,8 +149,8 @@ def masked_minimum(data, mask, dim=1): """ axis_maximums = math_ops.reduce_max(data, dim, keepdims=True) masked_minimums = math_ops.reduce_min( - math_ops.multiply( - data - axis_maximums, mask), dim, keepdims=True) + axis_maximums + math_ops.multiply(data - axis_maximums, mask), dim, + keepdims=True) + axis_maximums return masked_minimums @@ -202,8 +198,7 @@ def triplet_semihard_loss(labels, embeddings, margin=1.0): mask_final = array_ops.reshape( math_ops.greater( math_ops.reduce_sum( - math_ops.cast( - mask, dtype=dtypes.float32), 1, keepdims=True), + math_ops.cast(mask, dtype=dtypes.float32), 1, keepdims=True), 0.0), [batch_size, batch_size]) mask_final = array_ops.transpose(mask_final) @@ -450,8 +445,8 @@ def lifted_struct_loss(labels, embeddings, margin=1.0): # this is to take the max only among negatives. row_minimums = math_ops.reduce_min(diff, 1, keepdims=True) row_negative_maximums = math_ops.reduce_max( - math_ops.multiply( - diff - row_minimums, mask), 1, keepdims=True) + row_minimums + math_ops.multiply(diff - row_minimums, mask), 1, + keepdims=True) + row_minimums # Compute the loss. # Keep track of matrix of maximums where M_ij = max(m_i, m_j) @@ -467,10 +462,11 @@ def lifted_struct_loss(labels, embeddings, margin=1.0): array_ops.transpose(max_elements), [-1, 1]) loss_exp_left = array_ops.reshape( - math_ops.reduce_sum(math_ops.multiply( - math_ops.exp( - diff_tiled - max_elements_vect), - mask_tiled), 1, keepdims=True), [batch_size, batch_size]) + math_ops.reduce_sum( + math_ops.multiply( + math_ops.exp(diff_tiled - max_elements_vect), mask_tiled), + 1, + keepdims=True), [batch_size, batch_size]) loss_mat = max_elements + math_ops.log( loss_exp_left + array_ops.transpose(loss_exp_left)) diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index 81327407d44b4317b7aecb964a689a35aa35c163..05e8d9064bea748c935859f5f9b4c7e646f504cf 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -677,6 +677,7 @@ endif # TEGRA TF_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS)) # Add in any extra files that don't fit the patterns easily TF_CC_SRCS += tensorflow/contrib/makefile/downloads/fft2d/fftsg.c +TF_CC_SRCS += tensorflow/core/common_runtime/gpu/gpu_id_manager.cc # Also include the op and kernel definitions. TF_CC_SRCS += $(shell cat $(MAKEFILE_DIR)/tf_op_files.txt) PBT_CC_SRCS := $(shell cat $(MAKEFILE_DIR)/tf_pb_text_files.txt) diff --git a/tensorflow/contrib/makefile/README.md b/tensorflow/contrib/makefile/README.md index 6959ca344fa574eab85aca2386f94265d547b7f9..995230dfa848532dc2a50b85f58d19ba264f293e 100644 --- a/tensorflow/contrib/makefile/README.md +++ b/tensorflow/contrib/makefile/README.md @@ -130,6 +130,105 @@ adb shell '/data/local/tmp/benchmark \ For more details, see the [benchmark documentation](../../tools/benchmark). +## CUDA support for Tegra devices running Android (Nvidia Shield TV, etc) + +With the release of TF 1.6 and JetPack for Android 3.2 (currently pending), you can now build a version of TensorFlow for compatible devices according to the following instructions which will receive the full benefits of GPU acceleration. + +#### Environment setup: + +First, download and install JetPack for Android version 3.2 or greater from [Nvidia](https://developers.nvidia.com). Note that as of the TF 1.6 release the JetPack for Android 3.2 release is still pending, and regular JetPack for L4T will not work. + +```bash +git clone https://github.com/tensorflow/tensorflow.git +cd tensorflow +JETPACK=$HOME/JetPack_Android_3.2 +TEGRA_LIBS="$JETPACK/cuDNN/aarch64/cuda/lib64/libcudnn.so $JETPACK/cuda-9.0/extras/CUPTI/lib64/libcupti.so $JETPACK/cuda/targets/aarch64-linux-androideabi/lib64/libcufft.so" +``` + +#### Building all CUDA-enabled native binaries: +This will build CUDA-enabled versions of libtensorflow_inference.so and the benchmark binary. (libtensorflow_demo.so will also be built incidentally, but it does not support CUDA) + +```bash +NDK_ROOT=$JETPACK/android-ndk-r13b +CC_PREFIX=ccache tensorflow/contrib/makefile/build_all_android.sh -s tensorflow/contrib/makefile/sub_makefiles/android/Makefile.in -t "libtensorflow_inference.so libtensorflow_demo.so all" -a tegra +``` +(add -T on subsequent builds to skip protobuf downloading/building) + + +#### Testing the CUDA-enabled benchmark via adb: +Build binaries first as above, then run: + +```bash +adb shell mkdir -p /data/local/tmp/lib64 +adb push $TEGRA_LIBS /data/local/tmp/lib64 +adb push tensorflow/contrib/makefile/gen/bin/android_arm64-v8a/benchmark /data/local/tmp +wget https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk +unzip tensorflow_demo.apk -d /tmp/tensorflow_demo +adb push /tmp/tensorflow_demo/assets/*.pb /data/local/tmp +adb shell "LD_LIBRARY_PATH=/data/local/tmp/lib64 /data/local/tmp/benchmark --graph=/data/local/tmp/tensorflow_inception_graph.pb" +``` + +#### Building the CUDA-enabled TensorFlow AAR with Bazel: +Build the native binaries first as above. Then, build the aar and package the native libs by executing the following: +```bash +mkdir -p /tmp/tf/jni/arm64-v8a +cp tensorflow/contrib/makefile/gen/lib/android_tegra/libtensorflow_*.so /tmp/tf/jni/arm64-v8a/ +cp $TEGRA_LIBS /tmp/tf/jni/arm64-v8a +bazel build //tensorflow/contrib/android:android_tensorflow_inference_java.aar +cp bazel-bin/tensorflow/contrib/android/android_tensorflow_inference_java.aar /tmp/tf/tensorflow.aar +cd /tmp/tf +chmod +w tensorflow.aar +zip -ur tensorflow.aar $(find jni -name *.so) +``` + +#### Building the CUDA-enabled TensorFlow Android demo with Bazel: +Build binaries first as above, then edit tensorflow/examples/android/BUILD and replace: +``` + srcs = [ + ":libtensorflow_demo.so", + "//tensorflow/contrib/android:libtensorflow_inference.so", + ], +``` +with: +``` +srcs = glob(["libs/arm64-v8a/*.so"]), +``` + +Then run: +```bash +# Create dir for native libs +mkdir -p tensorflow/examples/android/libs/arm64-v8a + +# Copy JetPack libs +cp $TEGRA_LIBS tensorflow/examples/android/libs/arm64-v8a + +# Copy native TensorFlow libraries +cp tensorflow/contrib/makefile/gen/lib/android_arm64-v8a/libtensorflow_*.so tensorflow/examples/android/libs/arm64-v8a/ + +# Build APK +bazel build -c opt --fat_apk_cpu=arm64-v8a tensorflow/android:tensorflow_demo + +# Install +adb install -r -f bazel-bin/tensorflow/examples/android/tensorflow_demo.apk +``` + +#### Building the CUDA-enabled Android demo with gradle/Android Studio: + +Add tensorflow/examples/android as an Android project in Android Studio as normal. + +Edit build.gradle and: +* set nativeBuildSystem = 'makefile' +* set cpuType = 'arm64-v8a' +* in "buildNativeMake", replace cpuType with 'tegra' (optional speedups like -T and ccache also work) +* set the environment "NDK_ROOT" var to $JETPACK/android-ndk-r13b + +Click "build apk" to build. + +Install: +```bash +adb install -r -f tensorflow/examples/android/gradleBuild/outputs/apk/debug/android-debug.apk +``` + ## iOS _Note: To use this library in an iOS application, see related instructions in diff --git a/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh b/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh index 203ff4f890a3b0ed32caa1406508b100dd47bcad..421ddd210fd5b1ac6487918d5797eab5953316df 100755 --- a/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh +++ b/tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh @@ -36,7 +36,7 @@ while getopts "bc:Eps" opt_name; do b) BUILD_ONLY="true";; c) TEST_COUNT="${OPTARG}";; E) ENABLE_EXPERIMENTAL_HEXNN_OPS="true";; - p) USE_PREBUILT_HEXAOGON_BINARIES="true";; + p) USE_PREBUILT_HEXAGON_BINARIES="true";; s) SKIP_DOWNLOAD_IF_EXIST="true";; *) usage;; esac @@ -49,7 +49,7 @@ if [[ -z "${NDK_ROOT}" ]]; then exit 1 fi -if [[ "${USE_PREBUILT_HEXAOGON_BINARIES}" != "true" && +if [[ "${USE_PREBUILT_HEXAGON_BINARIES}" != "true" && -z "${QUALCOMM_SDK}" ]]; then echo "QUALCOMM_SDK is empty" 1>&2 usage @@ -84,7 +84,7 @@ rm -rf "${GEN_DIR}" mkdir -p "${GEN_LIBS_DIR}" mkdir -p "${GEN_DOWNLOAD_DIR}" -if [[ "${USE_PREBUILT_HEXAOGON_BINARIES}" == "true" ]]; then +if [[ "${USE_PREBUILT_HEXAGON_BINARIES}" == "true" ]]; then echo "Download prebuilt hexagon binaries" if [[ "${BUILD_ONLY}" != "true" ]]; then CONTROLLER_PUSH_DEST="/data/local/tmp" diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index 9de664c822bf7a9abf7b8082f444c61dfa45f499..e90c525113348532a3ebdadde7e712bf2d98cee9 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -43,6 +43,7 @@ py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:weights_broadcast_ops", + "//tensorflow/python/ops/distributions", ], ) diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index d3dce46bfb6e9c77cc7ae107b323a9bc7074c47e..de02dc8f457364450929776035829d86035d706b 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -16,6 +16,7 @@ See the @{$python/contrib.metrics} guide. +@@auc_with_confidence_intervals @@streaming_accuracy @@streaming_mean @@streaming_recall @@ -83,6 +84,7 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics +from tensorflow.contrib.metrics.python.ops.metric_ops import auc_with_confidence_intervals from tensorflow.contrib.metrics.python.ops.metric_ops import cohen_kappa from tensorflow.contrib.metrics.python.ops.metric_ops import count from tensorflow.contrib.metrics.python.ops.metric_ops import precision_recall_at_equal_thresholds diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index d3ce51a6112d955d012b4532ac727bf146f2c5cd..31e274c5fd7c670458b1b40a4f58c668a23776c7 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -38,6 +38,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops +from tensorflow.python.ops.distributions.normal import Normal from tensorflow.python.util.deprecation import deprecated # Epsilon constant used to represent extremely small quantity. @@ -1196,6 +1197,295 @@ def streaming_dynamic_auc(labels, return auc, update_op +def _compute_placement_auc(labels, predictions, weights, alpha, + logit_transformation, is_valid): + """Computes the AUC and asymptotic normally distributed confidence interval. + + The calculations are achieved using the fact that AUC = P(Y_1>Y_0) and the + concept of placement values for each labeled group, as presented by Delong and + Delong (1988). The actual algorithm used is a more computationally efficient + approach presented by Sun and Xu (2014). This could be slow for large batches, + but has the advantage of not having its results degrade depending on the + distribution of predictions. + + Args: + labels: A `Tensor` of ground truth labels with the same shape as + `predictions` with values of 0 or 1 and type `int64`. + predictions: A 1-D `Tensor` of predictions whose values are `float64`. + weights: `Tensor` whose rank is either 0, or the same rank as `labels`. + alpha: Confidence interval level desired. + logit_transformation: A boolean value indicating whether the estimate should + be logit transformed prior to calculating the confidence interval. Doing + so enforces the restriction that the AUC should never be outside the + interval [0,1]. + is_valid: A bool tensor describing whether the input is valid. + + Returns: + A 1-D `Tensor` containing the area-under-curve, lower, and upper confidence + interval values. + """ + # Disable the invalid-name checker so that we can capitalize the name. + # pylint: disable=invalid-name + AucData = collections_lib.namedtuple('AucData', ['auc', 'lower', 'upper']) + # pylint: enable=invalid-name + + # If all the labels are the same or if number of observations are too few, + # AUC isn't well-defined + size = array_ops.size(predictions, out_type=dtypes.int32) + + # Count the total number of positive and negative labels in the input. + total_0 = math_ops.reduce_sum( + math_ops.cast(1 - labels, weights.dtype) * weights) + total_1 = math_ops.reduce_sum( + math_ops.cast(labels, weights.dtype) * weights) + + # Sort the predictions ascending, as well as + # (i) the corresponding labels and + # (ii) the corresponding weights. + ordered_predictions, indices = nn.top_k(predictions, k=size, sorted=True) + ordered_predictions = array_ops.reverse( + ordered_predictions, axis=array_ops.zeros(1, dtypes.int32)) + indices = array_ops.reverse(indices, axis=array_ops.zeros(1, dtypes.int32)) + ordered_labels = array_ops.gather(labels, indices) + ordered_weights = array_ops.gather(weights, indices) + + # We now compute values required for computing placement values. + + # We generate a list of indices (segmented_indices) of increasing order. An + # index is assigned for each unique prediction float value. Prediction + # values that are the same share the same index. + _, segmented_indices = array_ops.unique(ordered_predictions) + + # We create 2 tensors of weights. weights_for_true is non-zero for true + # labels. weights_for_false is non-zero for false labels. + float_labels_for_true = math_ops.cast(ordered_labels, dtypes.float32) + float_labels_for_false = 1.0 - float_labels_for_true + weights_for_true = ordered_weights * float_labels_for_true + weights_for_false = ordered_weights * float_labels_for_false + + # For each set of weights with the same segmented indices, we add up the + # weight values. Note that for each label, we deliberately rely on weights + # for the opposite label. + weight_totals_for_true = math_ops.segment_sum(weights_for_false, + segmented_indices) + weight_totals_for_false = math_ops.segment_sum(weights_for_true, + segmented_indices) + + # These cumulative sums of weights importantly exclude the current weight + # sums. + cum_weight_totals_for_true = math_ops.cumsum(weight_totals_for_true, + exclusive=True) + cum_weight_totals_for_false = math_ops.cumsum(weight_totals_for_false, + exclusive=True) + + # Compute placement values using the formula. Values with the same segmented + # indices and labels share the same placement values. + placements_for_true = ( + (cum_weight_totals_for_true + weight_totals_for_true / 2.0) / + (math_ops.reduce_sum(weight_totals_for_true) + _EPSILON)) + placements_for_false = ( + (cum_weight_totals_for_false + weight_totals_for_false / 2.0) / + (math_ops.reduce_sum(weight_totals_for_false) + _EPSILON)) + + # We expand the tensors of placement values (for each label) so that their + # shapes match that of predictions. + placements_for_true = array_ops.gather(placements_for_true, segmented_indices) + placements_for_false = array_ops.gather(placements_for_false, + segmented_indices) + + # Select placement values based on the label for each index. + placement_values = ( + placements_for_true * float_labels_for_true + + placements_for_false * float_labels_for_false) + + # Split placement values by labeled groups. + placement_values_0 = placement_values * math_ops.cast( + 1 - ordered_labels, weights.dtype) + weights_0 = ordered_weights * math_ops.cast( + 1 - ordered_labels, weights.dtype) + placement_values_1 = placement_values * math_ops.cast( + ordered_labels, weights.dtype) + weights_1 = ordered_weights * math_ops.cast( + ordered_labels, weights.dtype) + + # Calculate AUC using placement values + auc_0 = (math_ops.reduce_sum(weights_0 * (1. - placement_values_0)) / + (total_0 + _EPSILON)) + auc_1 = (math_ops.reduce_sum(weights_1 * (placement_values_1)) / + (total_1 + _EPSILON)) + auc = array_ops.where(math_ops.less(total_0, total_1), auc_1, auc_0) + + # Calculate variance and standard error using the placement values. + var_0 = ( + math_ops.reduce_sum( + weights_0 * math_ops.square(1. - placement_values_0 - auc_0)) / + (total_0 - 1. + _EPSILON)) + var_1 = ( + math_ops.reduce_sum( + weights_1 * math_ops.square(placement_values_1 - auc_1)) / + (total_1 - 1. + _EPSILON)) + auc_std_err = math_ops.sqrt( + (var_0 / (total_0 + _EPSILON)) + (var_1 / (total_1 + _EPSILON))) + + # Calculate asymptotic normal confidence intervals + std_norm_dist = Normal(loc=0., scale=1.) + z_value = std_norm_dist.quantile((1.0 - alpha) / 2.0) + if logit_transformation: + estimate = math_ops.log(auc / (1. - auc + _EPSILON)) + std_err = auc_std_err / (auc * (1. - auc + _EPSILON)) + transformed_auc_lower = estimate + (z_value * std_err) + transformed_auc_upper = estimate - (z_value * std_err) + def inverse_logit_transformation(x): + exp_negative = math_ops.exp(math_ops.negative(x)) + return 1. / (1. + exp_negative + _EPSILON) + + auc_lower = inverse_logit_transformation(transformed_auc_lower) + auc_upper = inverse_logit_transformation(transformed_auc_upper) + else: + estimate = auc + std_err = auc_std_err + auc_lower = estimate + (z_value * std_err) + auc_upper = estimate - (z_value * std_err) + + ## If estimate is 1 or 0, no variance is present so CI = 1 + ## n.b. This can be misleading, since number obs can just be too low. + lower = array_ops.where( + math_ops.logical_or( + math_ops.equal(auc, array_ops.ones_like(auc)), + math_ops.equal(auc, array_ops.zeros_like(auc))), + auc, auc_lower) + upper = array_ops.where( + math_ops.logical_or( + math_ops.equal(auc, array_ops.ones_like(auc)), + math_ops.equal(auc, array_ops.zeros_like(auc))), + auc, auc_upper) + + # If all the labels are the same, AUC isn't well-defined (but raising an + # exception seems excessive) so we return 0, otherwise we finish computing. + trivial_value = array_ops.constant(0.0) + + return AucData(*control_flow_ops.cond( + is_valid, lambda: [auc, lower, upper], lambda: [trivial_value]*3)) + + +def auc_with_confidence_intervals(labels, + predictions, + weights=None, + alpha=0.95, + logit_transformation=True, + metrics_collections=(), + updates_collections=(), + name=None): + """Computes the AUC and asymptotic normally distributed confidence interval. + + USAGE NOTE: this approach requires storing all of the predictions and labels + for a single evaluation in memory, so it may not be usable when the evaluation + batch size and/or the number of evaluation steps is very large. + + Computes the area under the ROC curve and its confidence interval using + placement values. This has the advantage of being resilient to the + distribution of predictions by aggregating across batches, accumulating labels + and predictions and performing the final calculation using all of the + concatenated values. + + Args: + labels: A `Tensor` of ground truth labels with the same shape as `labels` + and with values of 0 or 1 whose values are castable to `int64`. + predictions: A `Tensor` of predictions whose values are castable to + `float64`. Will be flattened into a 1-D `Tensor`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`. + alpha: Confidence interval level desired. + logit_transformation: A boolean value indicating whether the estimate should + be logit transformed prior to calculating the confidence interval. Doing + so enforces the restriction that the AUC should never be outside the + interval [0,1]. + metrics_collections: An optional iterable of collections that `auc` should + be added to. + updates_collections: An optional iterable of collections that `update_op` + should be added to. + name: An optional name for the variable_scope that contains the metric + variables. + + Returns: + auc: A 1-D `Tensor` containing the current area-under-curve, lower, and + upper confidence interval values. + update_op: An operation that concatenates the input labels and predictions + to the accumulated values. + + Raises: + ValueError: If `labels`, `predictions`, and `weights` have mismatched shapes + or if `alpha` isn't in the range (0,1). + """ + if not (alpha > 0 and alpha < 1): + raise ValueError('alpha must be between 0 and 1; currently %.02f' % alpha) + + if weights is None: + weights = array_ops.ones_like(predictions) + + with variable_scope.variable_scope( + name, + default_name='auc_with_confidence_intervals', + values=[labels, predictions, weights]): + + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=predictions, + labels=labels, + weights=weights) + + total_weight = math_ops.reduce_sum(weights) + + weights = array_ops.reshape(weights, [-1]) + predictions = array_ops.reshape( + math_ops.cast(predictions, dtypes.float64), [-1]) + labels = array_ops.reshape(math_ops.cast(labels, dtypes.int64), [-1]) + + with ops.control_dependencies([ + check_ops.assert_greater_equal( + labels, + array_ops.zeros_like(labels, dtypes.int64), + message='labels must be 0 or 1, at least one is <0'), + check_ops.assert_less_equal( + labels, + array_ops.ones_like(labels, dtypes.int64), + message='labels must be 0 or 1, at least one is >1'), + ]): + preds_accum, update_preds = streaming_concat( + predictions, name='concat_preds') + labels_accum, update_labels = streaming_concat(labels, + name='concat_labels') + weights_accum, update_weights = streaming_concat( + weights, name='concat_weights') + update_op_for_valid_case = control_flow_ops.group( + update_labels, update_preds, update_weights) + + # Only perform updates if this case is valid. + all_labels_positive_or_0 = math_ops.logical_and( + math_ops.equal(math_ops.reduce_min(labels), 0), + math_ops.equal(math_ops.reduce_max(labels), 1)) + sums_of_weights_at_least_1 = math_ops.greater_equal(total_weight, 1.0) + is_valid = math_ops.logical_and(all_labels_positive_or_0, + sums_of_weights_at_least_1) + + update_op = control_flow_ops.cond( + sums_of_weights_at_least_1, + lambda: update_op_for_valid_case, control_flow_ops.no_op) + + auc = _compute_placement_auc( + labels_accum, + preds_accum, + weights_accum, + alpha=alpha, + logit_transformation=logit_transformation, + is_valid=is_valid) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + if metrics_collections: + ops.add_to_collections(metrics_collections, auc) + return auc, update_op + + def precision_recall_at_equal_thresholds(labels, predictions, weights=None, @@ -3430,6 +3720,7 @@ def cohen_kappa(labels, __all__ = [ + 'auc_with_confidence_intervals', 'aggregate_metric_map', 'aggregate_metrics', 'cohen_kappa', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index e067f08babd9a900e876545d427c91e5ff808f04..b387f26c0195432fb972dac450d2919bdaa702a1 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -1802,9 +1802,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.54166603, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.54166603, auc.eval(), delta=1e-3) def testAnotherAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1816,9 +1816,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3) def testThirdAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1830,9 +1830,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3) def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -1865,9 +1865,9 @@ class StreamingAUCTest(test.TestCase): auc, update_op = metrics.streaming_auc(predictions, labels, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(1, sess.run(update_op), 6) + self.assertAlmostEqual(0.49999976, sess.run(update_op), 6) - self.assertAlmostEqual(1, auc.eval(), 6) + self.assertAlmostEqual(0.49999976, auc.eval(), 6) def testWithMultipleUpdates(self): num_samples = 1000 @@ -2128,6 +2128,205 @@ class StreamingDynamicAUCTest(test.TestCase): self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-5) +class AucWithConfidenceIntervalsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def _testResultsEqual(self, expected_dict, gotten_result): + """Tests that 2 results (dicts) represent the same data. + + Args: + expected_dict: A dictionary with keys that are the names of properties + of PrecisionRecallData and whose values are lists of floats. + gotten_result: A AucWithConfidenceIntervalData object. + """ + gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()} + self.assertItemsEqual( + list(expected_dict.keys()), list(gotten_dict.keys())) + + for key, expected_values in expected_dict.items(): + self.assertAllClose(expected_values, gotten_dict[key]) + + def _testCase(self, predictions, labels, expected_result, weights=None): + """Performs a test given a certain scenario of labels, predictions, weights. + + Args: + predictions: The predictions tensor. Of type float32. + labels: The labels tensor. Of type bool. + expected_result: The expected result (dict) that maps to tensors. + weights: Optional weights tensor. + """ + with self.test_session() as sess: + predictions_tensor = constant_op.constant( + predictions, dtype=dtypes_lib.float32) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.int64) + weights_tensor = None + if weights: + weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32) + gotten_result, update_op = ( + metric_ops.auc_with_confidence_intervals( + labels=labels_tensor, + predictions=predictions_tensor, + weights=weights_tensor)) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self._testResultsEqual(expected_result, gotten_result) + + def testAucAllCorrect(self): + self._testCase( + predictions=[0., 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0], + labels=[0, 0, 1, 0, 0, 1, 0, 1, 1, 0], + expected_result={ + 'auc': 0.66666667, + 'lower': 0.27826795, + 'upper': 0.91208512, + }) + + def testAucUnorderedInput(self): + self._testCase( + predictions=[1.0, 0.6, 0., 0.3, 0.4, 0.2, 0.5, 0.3, 0.6, 0.8], + labels=[0, 1, 0, 1, 0, 0, 1, 0, 0, 1], + expected_result={ + 'auc': 0.66666667, + 'lower': 0.27826795, + 'upper': 0.91208512, + }) + + def testAucWithWeights(self): + self._testCase( + predictions=[0., 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0], + labels=[0, 0, 1, 0, 0, 1, 0, 1, 1, 0], + weights=[0.5, 0.6, 1.2, 1.5, 2.0, 2.0, 1.5, 1.2, 0.6, 0.5], + expected_result={ + 'auc': 0.65151515, + 'lower': 0.28918604, + 'upper': 0.89573906, + }) + + def testAucEqualOne(self): + self._testCase( + predictions=[0, 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0], + labels=[0, 0, 0, 0, 0, 1, 1, 1, 1, 1], + expected_result={ + 'auc': 1.0, + 'lower': 1.0, + 'upper': 1.0, + }) + + def testAucEqualZero(self): + self._testCase( + predictions=[0, 0.2, 0.3, 0.3, 0.4, 0.5, 0.6, 0.6, 0.8, 1.0], + labels=[1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + expected_result={ + 'auc': 0.0, + 'lower': 0.0, + 'upper': 0.0, + }) + + def testNonZeroOnePredictions(self): + self._testCase( + predictions=[2.5, -2.5, .5, -.5, 1], + labels=[1, 0, 1, 0, 0], + expected_result={ + 'auc': 0.83333333, + 'lower': 0.15229267, + 'upper': 0.99286517, + }) + + def testAllLabelsOnes(self): + self._testCase( + predictions=[1., 1., 1., 1., 1.], + labels=[1, 1, 1, 1, 1], + expected_result={ + 'auc': 0., + 'lower': 0., + 'upper': 0., + }) + + def testAllLabelsZeros(self): + self._testCase( + predictions=[0., 0., 0., 0., 0.], + labels=[0, 0, 0, 0, 0], + expected_result={ + 'auc': 0., + 'lower': 0., + 'upper': 0., + }) + + def testWeightSumLessThanOneAll(self): + self._testCase( + predictions=[1., 1., 0., 1., 0., 0.], + labels=[1, 1, 1, 0, 0, 0], + weights=[0.1, 0.1, 0.1, 0.1, 0.1, 0.1], + expected_result={ + 'auc': 0., + 'lower': 0., + 'upper': 0., + }) + + def testWithMultipleUpdates(self): + batch_size = 50 + num_batches = 100 + labels = np.array([]) + predictions = np.array([]) + tf_labels = variables.Variable(array_ops.ones(batch_size, dtypes_lib.int32), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.int32) + tf_predictions = variables.Variable( + array_ops.ones(batch_size), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + dtype=dtypes_lib.float32) + auc, update_op = metrics.auc_with_confidence_intervals(tf_labels, + tf_predictions) + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + for _ in xrange(num_batches): + new_labels = np.random.randint(0, 2, size=batch_size) + noise = np.random.normal(0.0, scale=0.2, size=batch_size) + new_predictions = 0.4 + 0.2 * new_labels + noise + labels = np.concatenate([labels, new_labels]) + predictions = np.concatenate([predictions, new_predictions]) + sess.run(tf_labels.assign(new_labels)) + sess.run(tf_predictions.assign(new_predictions)) + sess.run(update_op) + expected_auc = _np_auc(predictions, labels) + self.assertAllClose(expected_auc, auc.auc.eval()) + + def testExceptionOnFloatLabels(self): + with self.test_session() as sess: + predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) + labels = constant_op.constant([0.7, 0, 1, 0, 1]) + _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) + sess.run(variables.local_variables_initializer()) + self.assertRaises(TypeError, sess.run(update_op)) + + def testExceptionOnGreaterThanOneLabel(self): + with self.test_session() as sess: + predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) + labels = constant_op.constant([2, 1, 0, 1, 0]) + _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) + sess.run(variables.local_variables_initializer()) + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + '.*labels must be 0 or 1, at least one is >1.*'): + sess.run(update_op) + + def testExceptionOnNegativeLabel(self): + with self.test_session() as sess: + predictions = constant_op.constant([1, 0.5, 0, 1, 0], dtypes_lib.float32) + labels = constant_op.constant([1, 0, -1, 1, 0]) + _, update_op = metrics.auc_with_confidence_intervals(labels, predictions) + sess.run(variables.local_variables_initializer()) + with self.assertRaisesRegexp( + errors_impl.InvalidArgumentError, + '.*labels must be 0 or 1, at least one is <0.*'): + sess.run(update_op) + + class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): def setUp(self): @@ -6689,7 +6888,8 @@ class CohenKappaTest(test.TestCase): # [[0, 25, 0], # [0, 0, 25], # [25, 0, 0]] - # Calculated by v0.19: sklearn.metrics.cohen_kappa_score(labels, predictions) + # Calculated by v0.19: sklearn.metrics.cohen_kappa_score( + # labels, predictions) expect = -0.333333333333 with self.test_session() as sess: @@ -6748,7 +6948,8 @@ class CohenKappaTest(test.TestCase): weights_t: weights[batch_start:batch_end] }) # Calculated by v0.19: sklearn.metrics.cohen_kappa_score( - # labels_np, predictions_np, sample_weight=weights_np) + # labels_np, predictions_np, + # sample_weight=weights_np) expect = 0.289965397924 self.assertAlmostEqual(expect, kappa.eval(), 5) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py index d68ad23d65500cc2348459cdc53030c2ea08373a..9ce50bfe1054072b315adecb87f1ba729dfe0d83 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer.py @@ -83,7 +83,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): self._optimizer = opt self._ema = moving_averages.ExponentialMovingAverage( average_decay, num_updates=num_updates) - self._variable_map = None + self._swapped_variable_name_map = None self._sequential_update = sequential_update def compute_gradients(self, *args, **kwargs): @@ -93,7 +93,7 @@ class MovingAverageOptimizer(optimizer.Optimizer): train_op = self._optimizer.apply_gradients( grads_and_vars, global_step=global_step, name=name) var_list = [x[1] for x in grads_and_vars if x[0] is not None] - self._variable_map = {} + self._swapped_variable_name_map = {} if self._sequential_update: with ops.control_dependencies([train_op]): ma_op = self._ema.apply(var_list) @@ -102,9 +102,9 @@ class MovingAverageOptimizer(optimizer.Optimizer): for v in var_list: v_avg = self._ema.average(v) - self._variable_map[v.op.name] = v_avg - self._variable_map[v_avg.op.name] = v - return control_flow_ops.group(train_op, ma_op, name="train_with_avg") + self._swapped_variable_name_map[v.op.name] = v_avg.op.name + self._swapped_variable_name_map[v_avg.op.name] = v.op.name + return control_flow_ops.group(train_op, ma_op, name='train_with_avg') def swapping_saver(self, var_list=None, name='swapping_saver', **kwargs): """Create a saver swapping moving averages and variables. @@ -129,22 +129,45 @@ class MovingAverageOptimizer(optimizer.Optimizer): Raises: RuntimeError: If apply_gradients or minimize has not been called before. + ValueError: If var_list is provided and contains some variables but not + their moving average counterpart. """ - if self._variable_map is None: + if self._swapped_variable_name_map is None: raise RuntimeError('Must call apply_gradients or minimize before ' 'creating the swapping_saver') if var_list is None: var_list = variables.global_variables() if not isinstance(var_list, dict): var_list = saver.BaseSaverBuilder.OpListToDict(var_list) + + # OpListToDict converts variables to tensors. We make sure we can get + # the unique variable name for normal and resource vaiables. + def get_v_name(tensor): + if tensor.op.type == 'ReadVariableOp': + return tensor.op.inputs[0].op.name + else: + return tensor.op.name + + v_name_to_tensor = {} + for tensor in six.itervalues(var_list): + v_name = get_v_name(tensor) + v_name_to_tensor[v_name] = tensor + # Now swap variables and moving averages swapped_var_list = {} - for k, v in six.iteritems(var_list): - v_swap = self._variable_map.get(v.op.name, None) - if v_swap: - swapped_var_list[k] = v_swap - else: - swapped_var_list[k] = v + for k, tensor in six.iteritems(var_list): + v_name = get_v_name(tensor) + swapped_v_name = self._swapped_variable_name_map.get(v_name, None) + tensor_to_save = tensor + if swapped_v_name is not None: + if swapped_v_name in v_name_to_tensor: + tensor_to_save = v_name_to_tensor[swapped_v_name] + else: + raise ValueError( + ('Variable to swap %s is not part of variables to save. ' + 'This breaks MovingAverageOptimizer.') % swapped_v_name) + swapped_var_list[k] = tensor_to_save + # Build the swapping saver. return saver.Saver(swapped_var_list, name=name, **kwargs) diff --git a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py index 60929add198f2e69b5acc2eb5516dafc82b1f3ba..85e3e8d3791f2331ed249c0b7f67a3dbde4fca08 100644 --- a/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py +++ b/tensorflow/contrib/opt/python/training/moving_average_optimizer_test.py @@ -24,6 +24,10 @@ import six from tensorflow.contrib.opt.python.training import moving_average_optimizer from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import gradient_descent @@ -33,13 +37,26 @@ from tensorflow.python.training import saver class MovingAverageOptimizerTest(test.TestCase): def testRun(self): + self._helpTestRun(use_resource=False) + + def testRunUseResource(self): + # Test that MovingAverageOptimizer works with resource variables. + self._helpTestRun(use_resource=True) + + def _helpTestRun(self, use_resource=False): for sequential_update in [True, False]: for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: - with self.test_session() as sess: + with self.test_session(graph=ops.Graph()) as sess: orig_val0 = [1.0, 2.0] orig_val1 = [3.0, 4.0] - var0 = variables.Variable(orig_val0, name='var0', dtype=dtype) - var1 = variables.Variable(orig_val1, name='var1', dtype=dtype) + var0 = variable_scope.get_variable( + 'var0', + initializer=constant_op.constant(orig_val0, dtype=dtype), + use_resource=use_resource) + var1 = variable_scope.get_variable( + 'var1', + initializer=constant_op.constant(orig_val1, dtype=dtype), + use_resource=use_resource) grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) @@ -52,22 +69,63 @@ class MovingAverageOptimizerTest(test.TestCase): save_path = os.path.join(save_dir, 'model') update = opt.apply_gradients( list(six.moves.zip([grads0, grads1], [var0, var1]))) + global_vars = variables.global_variables() + ema_var0 = [ + v for v in global_vars + if v.op.name == 'var0/ExponentialMovingAverage' + ][0] + ema_var1 = [ + v for v in global_vars + if v.op.name == 'var1/ExponentialMovingAverage' + ][0] + perturb = control_flow_ops.group([ + state_ops.assign_add(var0, [1.0, 1.0]), + state_ops.assign_add(var1, [2.0, 2.0]), + state_ops.assign_add(ema_var0, [3.0, 3.0]), + state_ops.assign_add(ema_var1, [4.0, 4.0]) + ]) + + # Test taht saver with missing ema variables will fail. + with self.assertRaisesRegexp(ValueError, r'Variable to swap'): + opt.swapping_saver(var_list=[var0]) + train_saver = opt.swapping_saver() + train_saver_subset = opt.swapping_saver(var_list=[var0, ema_var0]) inference_saver = saver.Saver() variables.global_variables_initializer().run() # Step 1. update.run() - val0 = var0.eval() - val1 = var1.eval() self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) + if sequential_update: + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) # Test that the swapping saver save/restore operation is identity. train_saver.save(sess, save_path) train_saver.restore(sess, save_path) - val0 = var0.eval() - val1 = var1.eval() self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) self.assertAllCloseAccordingToType([2.98, 3.98], var1.eval()) + if sequential_update: + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([2.99, 3.99], ema_var1.eval()) + # Test that the subset saver saves the EMA variable as well. + if sequential_update: + subset_save_path = save_path + '_subset' + train_saver_subset.save(sess, subset_save_path) + perturb.run() + self.assertAllCloseAccordingToType([1.8, 2.8], var0.eval()) + self.assertAllCloseAccordingToType([3.9, 4.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) + self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) + # Restoring should only restore var0 and ema_var0. + train_saver_subset.restore(sess, subset_save_path) + self.assertAllCloseAccordingToType([0.8, 1.8], var0.eval()) + self.assertAllCloseAccordingToType([0.9, 1.9], ema_var0.eval()) + self.assertAllCloseAccordingToType([4.98, 5.98], var1.eval()) + self.assertAllCloseAccordingToType([6.99, 7.99], ema_var1.eval()) + # Restore back to previou state. + train_saver.restore(sess, save_path) + # If updates are parallel, this is not always true after the 1st step. if sequential_update: # Test that the normal saver will have the averaged variables. diff --git a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer.py b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer.py index 74036082f0ca2bae23b30deb1b1986befd6601d8..3c0b8394be51e8744b5461a00a99ead5e45d90b2 100644 --- a/tensorflow/contrib/opt/python/training/variable_clipping_optimizer.py +++ b/tensorflow/contrib/opt/python/training/variable_clipping_optimizer.py @@ -109,7 +109,7 @@ class VariableClippingOptimizer(optimizer.Optimizer): def _clip_dense(self, var): with self._maybe_colocate_with(var): - updated_var_value = var._ref() # pylint: disable=protected-access + updated_var_value = var.read_value() normalized_var = clip_ops.clip_by_norm( updated_var_value, self._max_norm, self._vars_to_clip_dims[var]) delta = updated_var_value - normalized_var diff --git a/tensorflow/contrib/py2tf/__init__.py b/tensorflow/contrib/py2tf/__init__.py index 379fa7fd5c2a22b5b16a21cca8c2ea8afdcaeefa..6531183cb59af774299eb767cce111d2ec6f32b4 100644 --- a/tensorflow/contrib/py2tf/__init__.py +++ b/tensorflow/contrib/py2tf/__init__.py @@ -23,6 +23,7 @@ from __future__ import print_function from tensorflow.contrib.py2tf import utils from tensorflow.contrib.py2tf.impl.api import convert +from tensorflow.contrib.py2tf.impl.api import converted_call from tensorflow.contrib.py2tf.impl.api import graph_ready from tensorflow.contrib.py2tf.impl.api import to_code from tensorflow.contrib.py2tf.impl.api import to_graph @@ -30,7 +31,8 @@ from tensorflow.contrib.py2tf.pyct.transformer import PyFlowParseError from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'to_graph', 'to_code', 'convert', 'graph_ready', 'utils', 'PyFlowParseError' + 'to_graph', 'to_code', 'convert', 'graph_ready', 'converted_call', 'utils', + 'PyFlowParseError' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/py2tf/converters/BUILD b/tensorflow/contrib/py2tf/converters/BUILD index 2e67ba221b01d9ce0670c7a958e1014df725c7f2..78f46bc05f2e6f4c5e0b6868ce93dbdeb8c7625a 100644 --- a/tensorflow/contrib/py2tf/converters/BUILD +++ b/tensorflow/contrib/py2tf/converters/BUILD @@ -18,14 +18,16 @@ py_library( name = "converters", srcs = [ "asserts.py", - "break_canonicalization.py", + "break_statements.py", "builtin_functions.py", "call_trees.py", - "continue_canonicalization.py", + "continue_statements.py", "control_flow.py", "decorators.py", - "for_canonicalization.py", + "for_loops.py", + "list_comprehension.py", "logical_expressions.py", + "name_scopes.py", "side_effect_guards.py", ], srcs_version = "PY2AND3", @@ -44,6 +46,7 @@ py_library( visibility = ["//tensorflow:__subpackages__"], deps = [ ":converters", + "//tensorflow/contrib/py2tf/pyct", "//tensorflow/contrib/py2tf/pyct/static_analysis", "//tensorflow/contrib/py2tf/utils", "@gast_archive//:gast", @@ -57,18 +60,16 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) py_test( - name = "break_canonicalization_test", - srcs = ["break_canonicalization_test.py"], + name = "break_statements_test", + srcs = ["break_statements_test.py"], srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -79,7 +80,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -90,18 +90,17 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", + "//tensorflow/contrib/py2tf/impl", "//tensorflow/python:client_testlib", ], ) py_test( - name = "continue_canonicalization_test", - srcs = ["continue_canonicalization_test.py"], + name = "continue_statements_test", + srcs = ["continue_statements_test.py"], srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -112,7 +111,6 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -123,15 +121,23 @@ py_test( srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) py_test( - name = "for_canonicalization_test", - srcs = ["for_canonicalization_test.py"], + name = "for_loops_test", + srcs = ["for_loops_test.py"], srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "name_scopes_test", + srcs = ["name_scopes_test.py"], deps = [ ":test_lib", "//tensorflow/contrib/py2tf/pyct", @@ -139,13 +145,22 @@ py_test( ], ) +py_test( + name = "list_comprehension_test", + srcs = ["list_comprehension_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":test_lib", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "logical_expressions_test", srcs = ["logical_expressions_test.py"], srcs_version = "PY2AND3", deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) @@ -154,9 +169,13 @@ py_test( name = "side_effect_guards_test", srcs = ["side_effect_guards_test.py"], srcs_version = "PY2AND3", + tags = [ + # TODO(mdan): Fix. + "flaky", + "notap", + ], deps = [ ":test_lib", - "//tensorflow/contrib/py2tf/pyct", "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/py2tf/converters/break_canonicalization.py b/tensorflow/contrib/py2tf/converters/break_statements.py similarity index 100% rename from tensorflow/contrib/py2tf/converters/break_canonicalization.py rename to tensorflow/contrib/py2tf/converters/break_statements.py diff --git a/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/break_statements_test.py similarity index 91% rename from tensorflow/contrib/py2tf/converters/break_canonicalization_test.py rename to tensorflow/contrib/py2tf/converters/break_statements_test.py index 2243398100880483d40a1ba7451a229e0dbe115b..095fcdff07d44ecc6b9bb7f8d3e2c7c43df72a02 100644 --- a/tensorflow/contrib/py2tf/converters/break_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/break_statements_test.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for break_canonicalization module.""" +"""Tests for break_statements module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import break_canonicalization +from tensorflow.contrib.py2tf.converters import break_statements from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.python.platform import test @@ -37,7 +37,7 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): return v node = self.parse_and_analyze(test_fn, {}) - node = break_canonicalization.transform(node, self.ctx) + node = break_statements.transform(node, self.ctx) with self.compiled(node) as result: self.assertEqual(test_fn(0), result.test_fn(0)) @@ -69,7 +69,7 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): return v node = self.parse_and_analyze(test_fn, {}) - node = break_canonicalization.transform(node, self.ctx) + node = break_statements.transform(node, self.ctx) with self.compiled(node) as result: # The break is incompletely canonicalized. Everything is in place, but @@ -98,7 +98,7 @@ class BreakCanonicalizationTest(converter_test_base.TestCase): return v, u, w node = self.parse_and_analyze(test_fn, {}) - node = break_canonicalization.transform(node, self.ctx) + node = break_statements.transform(node, self.ctx) with self.compiled(node) as result: self.assertEqual(test_fn(0), result.test_fn(0)) diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions.py b/tensorflow/contrib/py2tf/converters/builtin_functions.py index 310681dd016ca94bf2b28d27a4968cc0c10a5842..b5aa9756da6a139e542e9a0ead86cf4cc8207449 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions.py +++ b/tensorflow/contrib/py2tf/converters/builtin_functions.py @@ -25,36 +25,37 @@ from tensorflow.contrib.py2tf.pyct import transformer class BuiltinFunctionTransformer(transformer.Base): - """Handles builtin functions and canonicalizes old-style print statement. + """Handles builtin functions. This transformer only covers functions that are translated into a TF equivalent, like `len`. - Note that the `print` statement is converted to a function call here, but - wrapping the print function to a `py_func` is done by `call_trees` as a - generic uncompilable function wrap. """ - # TODO(mdan): Handle print entirely in here. - # Fully handling print here makes sense especially since we're considering - # using tf.Print instead. - def __init__(self, context): super(BuiltinFunctionTransformer, self).__init__(context) - def _convert_len(self, node): + # pylint:disable=invalid-name + + def _convert_builtin(self, node): template = """ - tf.shape(args)[0] + py2tf_utils.dynamic_builtin(func, args) """ - new_call = templates.replace(template, args=node.args)[0].value - return new_call + return templates.replace(template, func=node.func, args=node.args)[0].value - # pylint:disable=invalid-name + def _convert_print(self, node): + template = """ + py2tf_utils.dynamic_print(args) + """ + return templates.replace(template, args=node.args)[0].value def visit_Call(self, node): self.generic_visit(node) # TODO(mdan): This won't work if the function was hidden. - if isinstance(node.func, gast.Name) and node.func.id == 'len': - return self._convert_len(node) + if isinstance(node.func, gast.Name) and node.func.id in ('len',): + return self._convert_builtin(node) + # Print needs to be handled separately because it can be read as statement. + if isinstance(node.func, gast.Name) and node.func.id == 'print': + return self._convert_print(node) return node def visit_Print(self, node): @@ -66,7 +67,8 @@ class BuiltinFunctionTransformer(transformer.Base): template = """ fname(args) """ - return templates.replace(template, fname='print', args=args) + function_call = templates.replace(template, fname='print', args=args)[0] + return self.visit(function_call) # pylint:enable=invalid-name diff --git a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py b/tensorflow/contrib/py2tf/converters/builtin_functions_test.py index 983d1ffc03466ab3e2148e8cdf6e54050b9d3947..eb60a1d8ae2b56907df8f3ffafe7604883cfc2a9 100644 --- a/tensorflow/contrib/py2tf/converters/builtin_functions_test.py +++ b/tensorflow/contrib/py2tf/converters/builtin_functions_test.py @@ -26,6 +26,8 @@ from tensorflow.contrib.py2tf.converters import builtin_functions from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -45,7 +47,9 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): sess.run( result.test_fn(constant_op.constant([0, 0, 0])))) - def test_print(self): + self.assertEqual(3, result.test_fn([0, 0, 0])) + + def test_print_with_op(self): def test_fn(a): print(a) @@ -53,16 +57,41 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {'print': print}) node = builtin_functions.transform(node, self.ctx) - with self.compiled(node) as result: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn('a') - self.assertEqual(out_capturer.getvalue(), 'a\n') - finally: - sys.stdout = sys.__stdout__ + # Note: it's relevant not to include script_ops.py_func here, to verify + # that tf.Print is used. + with self.compiled(node, logging_ops.Print) as result: + with self.test_session() as sess: + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + result.test_fn('a') + sess.run(sess.graph.get_operations()) + self.assertEqual(out_capturer.getvalue(), 'a\n') + finally: + sys.stdout = sys.__stdout__ + + def test_print_with_op_multiple_values(self): + + def test_fn(a, b): + print(a, b) + + node = self.parse_and_analyze(test_fn, {'print': print}) + node = builtin_functions.transform(node, self.ctx) + + # Note: it's relevant not to include script_ops.py_func here, to verify + # that tf.Print is used. + with self.compiled(node, logging_ops.Print) as result: + with self.test_session() as sess: + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + result.test_fn('a', 1) + sess.run(sess.graph.get_operations()) + self.assertEqual(out_capturer.getvalue(), 'a 1\n') + finally: + sys.stdout = sys.__stdout__ - def test_print_tuple(self): + def test_print_with_py_func(self): def test_fn(a, b, c): print(a, b, c) @@ -70,18 +99,18 @@ class BuiltinFunctionsTest(converter_test_base.TestCase): node = self.parse_and_analyze(test_fn, {'print': print}) node = builtin_functions.transform(node, self.ctx) - with self.compiled(node) as result: - try: - out_capturer = six.StringIO() - sys.stdout = out_capturer - result.test_fn('a', 1, [2, 3]) - # It appears that the print output looks odd only under Python 2. - if six.PY2: - self.assertEqual(out_capturer.getvalue(), "('a', 1, [2, 3])\n") - else: + # Note: it's relevant not to include logging_ops.Print here, to verify + # that py_func is used. + with self.compiled(node, script_ops.py_func) as result: + with self.test_session() as sess: + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + result.test_fn('a', 1, [2, 3]) + sess.run(sess.graph.get_operations()) self.assertEqual(out_capturer.getvalue(), 'a 1 [2, 3]\n') - finally: - sys.stdout = sys.__stdout__ + finally: + sys.stdout = sys.__stdout__ if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/converters/call_trees.py b/tensorflow/contrib/py2tf/converters/call_trees.py index 60096d5a7b7c6eea9f3ade75ba492d54732b3550..f18f9f608671f90f08f8607f752fc76713a42874 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees.py +++ b/tensorflow/contrib/py2tf/converters/call_trees.py @@ -27,11 +27,10 @@ import types import gast from tensorflow.contrib.py2tf.pyct import anno +from tensorflow.contrib.py2tf.pyct import inspect_utils from tensorflow.contrib.py2tf.pyct import parser -from tensorflow.contrib.py2tf.pyct import qual_names from tensorflow.contrib.py2tf.pyct import templates from tensorflow.contrib.py2tf.pyct import transformer -from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno from tensorflow.python.util import tf_inspect @@ -74,9 +73,8 @@ class CallTreeTransformer(transformer.Base): self.uncompiled_modules = uncompiled_modules self.nocompile_decorators = nocompile_decorators - # pylint:disable=invalid-name - def _resolve_name(self, node): + """Used to resolve decorator info.""" if isinstance(node, gast.Call): return self._resolve_name(node.func) if isinstance(node, gast.Name): @@ -101,7 +99,13 @@ class CallTreeTransformer(transformer.Base): (owner_type, node.attr)) return None + def _function_is_compilable(self, target_entity): + """Determines whether an entity can be compiled at all.""" + # TODO(mdan): This is just a placeholder. Implement. + return not isinstance(target_entity, types.BuiltinFunctionType) + def _should_compile(self, node, fqn): + """Determines whether an entity should be compiled in the context.""" for i in range(1, len(fqn)): if fqn[:i] in self.uncompiled_modules: return False @@ -143,33 +147,6 @@ class CallTreeTransformer(transformer.Base): return True - def _determine_function_owner(self, m): - # TODO(mdan): The parent type should be known at analysis. Use that instead. - if hasattr(m, 'im_class'): # Python 2 - return m.im_class - if hasattr(m, '__qualname__'): # Python 3 - # Object attributes: should be bound to "self". - if hasattr(m, '__self__'): - return type(m.__self__) - - # Class attributes: should have the owner name in their namespace. - qn = m.__qualname__.split('.') - if len(qn) < 2: - return None - owner_name, func_name = qn[-2:] - if func_name != m.__name__: - raise ValueError('Inconsistent names detected ' - '(__qualname__[1] = "%s", __name__ = "%s") for %s.' % - (func_name, m.__name__, m)) - if owner_name == '': - return None - if owner_name not in self.context.namespace: - raise ValueError( - 'Could not resolve name "%s" while analyzing %s. Namespace:\n%s' % - (owner_name, m, self.context.namespace)) - return self.context.namespace[owner_name] - return None - def _rename_compilable_function(self, node): assert anno.hasanno(node.func, 'live_val') assert anno.hasanno(node.func, 'fqn') @@ -184,7 +161,11 @@ class CallTreeTransformer(transformer.Base): target_fqn, live_entity=target_entity) do_rename = True else: - owner_type = self._determine_function_owner(target_entity) + if anno.hasanno(node.func, 'parent_type'): + owner_type = anno.getanno(node.func, 'parent_type') + else: + # Fallback - not reliable. + owner_type = inspect_utils.getmethodclass(target_entity) new_name, do_rename = self.context.namer.compiled_function_name( target_fqn, live_entity=target_entity, owner_type=owner_type) @@ -198,40 +179,38 @@ class CallTreeTransformer(transformer.Base): return node def _wrap_to_py_func_no_return(self, node): - func_qn = anno.getanno(node.func, anno.Basic.QN) - args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE) - wrapper_name = self.context.namer.new_symbol(func_qn.ssf(), - args_scope.referenced) - wrapper_args = [] - for arg in node.args: - if anno.hasanno(arg, anno.Basic.QN): - arg_qn = anno.getanno(arg, anno.Basic.QN) - else: - arg_qn = qual_names.QN('arg') - wrapper_args.append( - self.context.namer.new_symbol(arg_qn.ssf(), args_scope.referenced)) # TODO(mdan): Properly handle varargs, kwargs, etc. - # TODO(mdan): This is best handled as a dynamic dispatch. - # That way we can separate tensors from non-tensor args. template = """ - def wrapper(wrapper_args): - call(wrapper_args) - return 1 - tf.py_func(wrapper, original_args, [tf.int64]) + py2tf_utils.wrap_py_func(func, None, (original_args,), True) """ - wrapper_def, call_expr = templates.replace( - template, - call=node.func, - wrapper=wrapper_name, - original_args=gast.List(elts=node.args, ctx=None), - wrapper_args=wrapper_args) - anno.setanno(wrapper_def, anno.Basic.SKIP_PROCESSING, True) - - return (wrapper_def, call_expr) + return templates.replace(template, func=node.func, original_args=node.args) + + def _converted_call(self, node): + """Inlines a dynamic conversion for a dynamic function.""" + # TODO(mdan): Pass information on the statically compiled functions. + # Having access to the statically compiled functions can help avoid + # unnecessary compilation. + # For example, this would lead to function `a` being compiled twice: + # + # def a(): + # v = b + # b() + # def b(): + # a() + # + # This is really a problem with recursive calls, which currently can + # only be gated by a static condition, and should be rare. + # TODO(mdan): It probably makes sense to use dynamic conversion every time. + # Before we could convert all the time though, we'd need a reasonable + # caching mechanism. + template = """ + py2tf_api.converted_call(func, True, False, {}, original_args) + """ + call_expr = templates.replace( + template, func=node.func, original_args=node.args) + return call_expr[0].value - def _function_is_compilable(self, target_entity): - # TODO(mdan): This is just a placeholder. Implement. - return not isinstance(target_entity, types.BuiltinFunctionType) + # pylint:disable=invalid-name def visit_Expr(self, node): if isinstance(node.value, gast.Call): @@ -272,9 +251,9 @@ class CallTreeTransformer(transformer.Base): raise NotImplementedError('py_func with return values') else: if self.context.recursive: - raise NotImplementedError('Could not resolve target function.') + node = self._converted_call(node) else: - # TODO(mdan): Double check. Is this reachable code? + # Unresolved functions are allowed in non-recursive mode. pass return node diff --git a/tensorflow/contrib/py2tf/converters/call_trees_test.py b/tensorflow/contrib/py2tf/converters/call_trees_test.py index 18a5c1e6e35f49a01649c17e6cd5647389e1f526..d482a9ef7897388839bbf8f9e4bfc5839d42b2d7 100644 --- a/tensorflow/contrib/py2tf/converters/call_trees_test.py +++ b/tensorflow/contrib/py2tf/converters/call_trees_test.py @@ -47,6 +47,21 @@ class CallTreesTest(converter_test_base.TestCase): result.renamed_test_fn_1 = renamed_test_fn_1 self.assertEquals(3, result.test_fn_2(1)) + def test_dynamic_function(self): + + def test_fn_1(): + raise ValueError('This should be masked by the mock.') + + def test_fn_2(f): + return f() + 3 + + node = self.parse_and_analyze(test_fn_2, {}) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node) as result: + # 10 = 7 (from the mock) + 3 (from test_fn_2) + self.assertEquals(10, result.test_fn_2(test_fn_1)) + def test_simple_methods(self): class TestClass(object): @@ -59,6 +74,7 @@ class CallTreesTest(converter_test_base.TestCase): node = self.parse_and_analyze( TestClass.test_fn_2, {'TestClass': TestClass}, + namer=converter_test_base.FakeNoRenameNamer(), arg_types={'self': (TestClass.__name__, TestClass)}) node = call_trees.transform(node, self.ctx, (), ()) @@ -66,6 +82,29 @@ class CallTreesTest(converter_test_base.TestCase): tc = TestClass() self.assertEquals(3, result.test_fn_2(tc, 1)) + def test_py_func_wrap_no_retval(self): + + def test_fn(a): + setattr(a, 'foo', 'bar') + + node = self.parse_and_analyze(test_fn, {'setattr': setattr}) + node = call_trees.transform(node, self.ctx, (), ()) + + with self.compiled(node) as result: + with self.test_session() as sess: + # The function has no return value, so we do some tricks to grab the + # generated py_func node and ensure its effect only happens at graph + # execution. + + class Dummy(object): + pass + + a = Dummy() + result.test_fn(a) + self.assertFalse(hasattr(a, 'foo')) + sess.run(sess.graph.get_operations()[0]) + self.assertEquals('bar', a.foo) + def test_uncompiled_modules(self): def test_fn(a): diff --git a/tensorflow/contrib/py2tf/converters/continue_canonicalization.py b/tensorflow/contrib/py2tf/converters/continue_statements.py similarity index 100% rename from tensorflow/contrib/py2tf/converters/continue_canonicalization.py rename to tensorflow/contrib/py2tf/converters/continue_statements.py diff --git a/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/continue_statements_test.py similarity index 89% rename from tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py rename to tensorflow/contrib/py2tf/converters/continue_statements_test.py index 2a0fb2d88b54114d558f1ea4cf9b1dc53b21e5cf..a598dcd1aed29478b7e3fe27e3c1b20010247dd9 100644 --- a/tensorflow/contrib/py2tf/converters/continue_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/continue_statements_test.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for continue_canonicalization module.""" +"""Tests for continue_statements module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.converters import continue_canonicalization +from tensorflow.contrib.py2tf.converters import continue_statements from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.python.platform import test @@ -37,7 +37,7 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): return v node = self.parse_and_analyze(test_fn, {}) - node = continue_canonicalization.transform(node, self.ctx) + node = continue_statements.transform(node, self.ctx) with self.compiled(node) as result: self.assertEqual(test_fn(0), result.test_fn(0)) @@ -58,7 +58,7 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): return v node = self.parse_and_analyze(test_fn, {}) - node = continue_canonicalization.transform(node, self.ctx) + node = continue_statements.transform(node, self.ctx) with self.compiled(node) as result: self.assertEqual(test_fn([]), result.test_fn([])) @@ -84,7 +84,7 @@ class ContinueCanonicalizationTest(converter_test_base.TestCase): return v, u, w node = self.parse_and_analyze(test_fn, {}) - node = continue_canonicalization.transform(node, self.ctx) + node = continue_statements.transform(node, self.ctx) with self.compiled(node) as result: self.assertEqual(test_fn(0), result.test_fn(0)) diff --git a/tensorflow/contrib/py2tf/converters/converter_test_base.py b/tensorflow/contrib/py2tf/converters/converter_test_base.py index 67747183dd323a799a04943ce4c7fe8c4093d002..1f98d8469c1b3032fe6babb5a63dde1747027f21 100644 --- a/tensorflow/contrib/py2tf/converters/converter_test_base.py +++ b/tensorflow/contrib/py2tf/converters/converter_test_base.py @@ -25,6 +25,7 @@ from tensorflow.contrib.py2tf import utils from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.contrib.py2tf.pyct import context from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct import pretty_printer from tensorflow.contrib.py2tf.pyct import qual_names from tensorflow.contrib.py2tf.pyct.static_analysis import activity from tensorflow.contrib.py2tf.pyct.static_analysis import live_values @@ -52,26 +53,43 @@ class FakeNamer(object): return ('renamed_%s' % '_'.join(original_fqn)), True +class FakeNoRenameNamer(FakeNamer): + + def compiled_function_name(self, original_fqn, **_): + return str(original_fqn), False + + class TestCase(test.TestCase): """Base class for unit tests in this module. Contains relevant utilities.""" @contextlib.contextmanager def compiled(self, node, *symbols): - source = '' + source = None + + self.dynamic_calls = [] + def converted_call(*args): + """Mock version of api.converted_call.""" + self.dynamic_calls.append(args) + return 7 + try: result, source = compiler.ast_to_object(node) - result.tf = self.make_fake_tf(*symbols) + result.tf = self.make_fake_mod('fake_tf', *symbols) result.py2tf_utils = utils + result.py2tf_api = self.make_fake_mod('fake_api', converted_call) yield result except Exception: # pylint:disable=broad-except - print('Offending compiled code:\n%s' % source) + if source is None: + print('Offending AST:\n%s' % pretty_printer.fmt(node, color=False)) + else: + print('Offending compiled code:\n%s' % source) raise - def make_fake_tf(self, *symbols): - fake_tf = imp.new_module('fake_tf') + def make_fake_mod(self, name, *symbols): + fake_mod = imp.new_module(name) for s in symbols: - setattr(fake_tf, s.__name__, s) - return fake_tf + setattr(fake_mod, s.__name__, s) + return fake_mod def attach_namespace(self, module, **ns): for k, v in ns.items(): @@ -83,6 +101,7 @@ class TestCase(test.TestCase): namer=None, arg_types=None, include_type_analysis=True, + owner_type=None, recursive=True): node, source = parser.parse_entity(test_fn) ctx = context.EntityContext( @@ -92,6 +111,7 @@ class TestCase(test.TestCase): namespace=namespace, arg_values=None, arg_types=arg_types, + owner_type=owner_type, recursive=recursive) node = qual_names.resolve(node) node = activity.resolve(node, ctx) diff --git a/tensorflow/contrib/py2tf/converters/decorators.py b/tensorflow/contrib/py2tf/converters/decorators.py index 3f620c1cd2d9b75f82410754a7e812e13eabe3ae..68bf241ef33292f0581ccb3c44f313f853c92ba7 100644 --- a/tensorflow/contrib/py2tf/converters/decorators.py +++ b/tensorflow/contrib/py2tf/converters/decorators.py @@ -33,6 +33,7 @@ class DecoratorsTransformer(gast.NodeTransformer): def __init__(self, remove_decorators): self.remove_decorators = remove_decorators + self.additional_dependencies = set() # pylint:disable=invalid-name @@ -44,13 +45,38 @@ class DecoratorsTransformer(gast.NodeTransformer): dec_func = dec.func else: dec_func = dec + + # Special cases. + # TODO(mdan): Is there any way we can treat these more generically? + # We may want to forego using decorators altogether if we can't + # properly support them. + if isinstance(dec_func, gast.Name) and dec_func.id in ('classmethod',): + # Assumption: decorators are only visible in the AST when converting + # a function inline (via another decorator). + # In that case, the converted function is no longer part of the + # original object that it was declared into. + # This is currently verified by tests. + continue + if not anno.hasanno(dec_func, 'live_val'): raise ValueError( 'Could not resolve decorator: %s' % pretty_printer.fmt(dec_func)) + dec_value = anno.getanno(dec_func, 'live_val') if dec_value not in self.remove_decorators: - kept_decorators.append(dec) - node.decorator_list = kept_decorators + kept_decorators.append((dec, dec_value)) + + for _, dec_value in kept_decorators: + if dec_value.__module__ == '__main__': + raise ValueError( + 'decorator "%s" was not allowed because it is declared ' + 'in the module "%s". To fix this, declare it in a separate ' + 'module that we can import it from.' % (dec_value, + dec_value.__module__)) + else: + self.additional_dependencies.add(dec_value) + + node.decorator_list = [dec for dec, _ in kept_decorators] return node # pylint:enable=invalid-name @@ -59,4 +85,4 @@ class DecoratorsTransformer(gast.NodeTransformer): def transform(node, remove_decorators): transformer = DecoratorsTransformer(remove_decorators) node = transformer.visit(node) - return node + return node, transformer.additional_dependencies diff --git a/tensorflow/contrib/py2tf/converters/decorators_test.py b/tensorflow/contrib/py2tf/converters/decorators_test.py index 402fa0dda28e696f70d0354ca4abf3a6c83506d9..c75e5461746f27d14a54b7ac06e7f77d868372c8 100644 --- a/tensorflow/contrib/py2tf/converters/decorators_test.py +++ b/tensorflow/contrib/py2tf/converters/decorators_test.py @@ -18,84 +18,121 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import textwrap +from functools import wraps from tensorflow.contrib.py2tf.converters import converter_test_base from tensorflow.contrib.py2tf.converters import decorators from tensorflow.contrib.py2tf.pyct import compiler from tensorflow.python.platform import test -from tensorflow.python.util import tf_inspect + + +# The Python parser only briefly captures decorators into the AST. +# The interpreter desugars them on load, and the decorated function loses any +# trace of the decorator (which is notmally what you would expect, since +# they are meant to be transparent). +# However, decorators are still visible when you analyze the function +# from inside a decorator, before it was applied - as is the case +# with our conversion decorators. + + +def simple_decorator(f): + return lambda a: f(a) + 1 + + +def self_removing_decorator(removing_wrapper): + def decorator(f): + @wraps(f) + def wrapper(*args): + # This removing wrapper is defined in the test below. This setup is so + # intricate just to simulate how we use the transformer in practice. + transformed_f = removing_wrapper(f, (self_removing_decorator,)) + return transformed_f(*args) + 1 + return wrapper + return decorator class DecoratorsTest(converter_test_base.TestCase): - def test_function_decorator(self): + def _remover_wrapper(self, f, remove_decorators): + namespace = { + 'self_removing_decorator': self_removing_decorator, + 'simple_decorator': simple_decorator + } + node = self.parse_and_analyze(f, namespace) + node, _ = decorators.transform(node, remove_decorators=remove_decorators) + result, _ = compiler.ast_to_object(node) + return getattr(result, f.__name__) - def function_decorator(): + def test_noop(self): - def decorator(f): - return lambda a: f(a) + 1 + def test_fn(a): + return a - return decorator + node = self.parse_and_analyze(test_fn, {}) + node, deps = decorators.transform(node, remove_decorators=()) + result, _ = compiler.ast_to_object(node) - # The Python parser does capture decorators into the AST. - # However, the interpreter desugars them on load, and refering to the - # decorated function at runtime usually loses any trace of the decorator. - # Below is an example when that doesn't happen. - def static_wrapper(): + self.assertFalse(deps) + self.assertEqual(1, result.test_fn(1)) - @function_decorator() - def test_fn(a): # pylint:disable=unused-variable - return a + def test_function(self): - node = self.parse_and_analyze(static_wrapper, - {'function_decorator': function_decorator}) - node = node.body[0].body[0] + @self_removing_decorator(self._remover_wrapper) + def test_fn(a): + return a - node = decorators.transform(node, remove_decorators=()) - # Since the decorator is not removed, we need to include its source - # code. We cannot do it after the fact because decorators are executed - # on load. - result, _ = compiler.ast_to_object( - node, - source_prefix=textwrap.dedent(tf_inspect.getsource(function_decorator))) - self.assertEqual(2, result.test_fn(1)) + # 2 = 1 (a) + 1 (decorator applied exactly once) + self.assertEqual(2, test_fn(1)) - node = decorators.transform(node, remove_decorators=(function_decorator,)) - with self.compiled(node) as result: - self.assertEqual(1, result.test_fn(1)) + def test_method(self): - def test_simple_decorator(self): + class TestClass(object): - def simple_decorator(f): - return lambda a: f(a) + 1 + @self_removing_decorator(self._remover_wrapper) + def test_fn(self, a): + return a - # The Python parser does capture decorators into the AST. - # However, the interpreter desugars them upon load, and refering to the - # decorated function at runtime usually loses any trace of the decorator. - # Below is an example when that doesn't happen. - def static_wrapper(): + # 2 = 1 (a) + 1 (decorator applied exactly once) + self.assertEqual(2, TestClass().test_fn(1)) - @simple_decorator - def test_fn(a): # pylint:disable=unused-variable + def test_multiple_decorators(self): + + class TestClass(object): + + # Note that reversing the order of this two doesn't work. + @classmethod + @self_removing_decorator(self._remover_wrapper) + def test_fn(cls, a): return a - node = self.parse_and_analyze(static_wrapper, - {'simple_decorator': simple_decorator}) - node = node.body[0].body[0] - - node = decorators.transform(node, remove_decorators=()) - # Since the decorator is not removed, we need to include its source - # code. We cannot do it after the fact because decorators are executed - # on load. - result, _ = compiler.ast_to_object( - node, - source_prefix=textwrap.dedent(tf_inspect.getsource(simple_decorator))) - self.assertEqual(2, result.test_fn(1)) - - node = decorators.transform(node, remove_decorators=(simple_decorator,)) - with self.compiled(node) as result: - self.assertEqual(1, result.test_fn(1)) + # 2 = 1 (a) + 1 (decorator applied exactly once) + self.assertEqual(2, TestClass.test_fn(1)) + + def test_nested_decorators(self): + + @self_removing_decorator(self._remover_wrapper) + def test_fn(a): + @simple_decorator + def inner_fn(b): + return b + 11 + return inner_fn(a) + + with self.assertRaises(ValueError): + test_fn(1) + + # TODO(mdan): Uncomment this test once converter_test_base is updated. + # (can't do it now because it has unrelated pending changes) + # def test_nested_decorators(self): + # + # @self_removing_decorator(self._remover_wrapper) + # def test_fn(a): + # @imported_decorator + # def inner_fn(b): + # return b + 11 + # return inner_fn(a) + # + # # 14 = 1 (a) + 1 (simple_decorator) + 11 (inner_fn) + # self.assertEqual(14, test_fn(1)) if __name__ == '__main__': diff --git a/tensorflow/contrib/py2tf/converters/for_canonicalization.py b/tensorflow/contrib/py2tf/converters/for_loops.py similarity index 100% rename from tensorflow/contrib/py2tf/converters/for_canonicalization.py rename to tensorflow/contrib/py2tf/converters/for_loops.py diff --git a/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py b/tensorflow/contrib/py2tf/converters/for_loops_test.py similarity index 88% rename from tensorflow/contrib/py2tf/converters/for_canonicalization_test.py rename to tensorflow/contrib/py2tf/converters/for_loops_test.py index 910c4dcc0081a5632e5324268c15fd3bde5d875b..70a367d3b517e528b67f260d607431d324d2ab7d 100644 --- a/tensorflow/contrib/py2tf/converters/for_canonicalization_test.py +++ b/tensorflow/contrib/py2tf/converters/for_loops_test.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for for_canonicalization module.""" +"""Tests for for_loops module.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from tensorflow.contrib.py2tf.converters import converter_test_base -from tensorflow.contrib.py2tf.converters import for_canonicalization +from tensorflow.contrib.py2tf.converters import for_loops from tensorflow.python.platform import test @@ -34,7 +34,7 @@ class ControlFlowTest(converter_test_base.TestCase): return s node = self.parse_and_analyze(test_fn, {}) - node = for_canonicalization.transform(node, self.ctx) + node = for_loops.transform(node, self.ctx) with self.compiled(node) as result: l = [1, 2, 3] diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension.py b/tensorflow/contrib/py2tf/converters/list_comprehension.py new file mode 100644 index 0000000000000000000000000000000000000000..e8744831100e4852919b5cd1253b74acea4d790d --- /dev/null +++ b/tensorflow/contrib/py2tf/converters/list_comprehension.py @@ -0,0 +1,80 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Canonicalizing list comprehensions into for and if statements. + +e.g. +result = [x * x for x in xs] + +becomes + +result = [] +for x in xs: + elt = x * x + result.append(elt) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer + + +class ListCompCanonicalizationTransformer(transformer.Base): + """NodeTransformer to canonicalize list comprehensions.""" + + def __init__(self, context): + super(ListCompCanonicalizationTransformer, self).__init__(context) + + def make_update_list_node(self, list_, elt): + return templates.replace('list_.append(elt)', list_=list_, elt=elt)[0] + + def instantiate_list_node(self): + return parser.parse_str('[]').body[0].value + + def visit_Assign(self, node): + if not isinstance(node.value, gast.ListComp): + return node + if len(node.targets) > 1: + raise ValueError('Only support single assignment.') + return self.canonicalize_listcomp(node.targets[0], node.value) + + def canonicalize_listcomp(self, result_node, list_comp_node): + + make_list = templates.replace( + 'list_ = create_list', + list_=result_node, + create_list=self.instantiate_list_node()) + loop_body = self.make_update_list_node(result_node, list_comp_node.elt) + + for gen in reversed(list_comp_node.generators): + for gen_if in reversed(gen.ifs): + loop_body = templates.replace( + 'if test: loop_body', test=gen_if, loop_body=loop_body) + loop_body = templates.replace( + 'for target in iter_: loop_body', + iter_=gen.iter, + target=gen.target, + loop_body=loop_body) + + return make_list + loop_body + + +def transform(node, context): + return ListCompCanonicalizationTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/list_comprehension_test.py b/tensorflow/contrib/py2tf/converters/list_comprehension_test.py new file mode 100644 index 0000000000000000000000000000000000000000..025fac11e41e6771fbb9b80ff3da70dc3ceec73e --- /dev/null +++ b/tensorflow/contrib/py2tf/converters/list_comprehension_test.py @@ -0,0 +1,75 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for list_comprehension module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.py2tf.converters import list_comprehension +from tensorflow.python.platform import test + + +class ListCompTest(converter_test_base.TestCase): + + def test_basic(self): + + def test_fn(l): + s = [e * e for e in l] + return s + + node = self.parse_and_analyze(test_fn, {}) + node = list_comprehension.transform(node, self.ctx) + + with self.compiled(node) as result: + l = [1, 2, 3] + self.assertEqual(test_fn(l), result.test_fn(l)) + l = [] + self.assertEqual(test_fn(l), result.test_fn(l)) + + def test_multiple_generators(self): + + def test_fn(l): + s = [e * e for sublist in l for e in sublist] + return s + + node = self.parse_and_analyze(test_fn, {}) + node = list_comprehension.transform(node, self.ctx) + + with self.compiled(node) as result: + l = [[1], [2], [3]] + self.assertEqual(test_fn(l), result.test_fn(l)) + l = [] + self.assertEqual(test_fn(l), result.test_fn(l)) + + def test_conds(self): + + def test_fn(l): + s = [e * e for e in l if e > 1] + return s + + node = self.parse_and_analyze(test_fn, {}) + node = list_comprehension.transform(node, self.ctx) + + with self.compiled(node) as result: + l = [1, 2, 3] + self.assertEqual(test_fn(l), result.test_fn(l)) + l = [] + self.assertEqual(test_fn(l), result.test_fn(l)) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/converters/name_scopes.py b/tensorflow/contrib/py2tf/converters/name_scopes.py new file mode 100644 index 0000000000000000000000000000000000000000..c702823fcf047fcad3254318bd323d2b8fddd700 --- /dev/null +++ b/tensorflow/contrib/py2tf/converters/name_scopes.py @@ -0,0 +1,52 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wraps a function body with a `name_scope` of the function name. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gast + +from tensorflow.contrib.py2tf.pyct import templates +from tensorflow.contrib.py2tf.pyct import transformer + + +class FunctionNameScopeTransformer(transformer.Base): + """Wrap a function body with a `name_scope` of the function name.""" + + def __init__(self, context): + super(FunctionNameScopeTransformer, self).__init__(context) + self._function_level = 0 + + def visit_FunctionDef(self, node): + self._function_level += 1 + try: + self.generic_visit(node) + finally: + self._function_level -= 1 + scope_name = node.name + if self._function_level == 0 and self.context.owner_type is not None: + scope_name = '{}/{}'.format(self.context.owner_type.__name__, scope_name) + node.body = templates.replace( + 'with tf.name_scope(scope_name): body', + scope_name=gast.Str(scope_name), + body=node.body) + return node + + +def transform(node, context): + return FunctionNameScopeTransformer(context).visit(node) diff --git a/tensorflow/contrib/py2tf/converters/name_scopes_test.py b/tensorflow/contrib/py2tf/converters/name_scopes_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a8ca341602ee5f06dbb812643a58794339d98afe --- /dev/null +++ b/tensorflow/contrib/py2tf/converters/name_scopes_test.py @@ -0,0 +1,92 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for for_canonicalization module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.converters import converter_test_base +from tensorflow.contrib.py2tf.converters import name_scopes +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.platform import test + + +class FunctionNameScopeTransformer(converter_test_base.TestCase): + + def test_basic_name(self): + + def test_fn(l): + a = 5 + l += a + return l + + node = self.parse_and_analyze(test_fn, {}) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.test_fn(constant_op.constant([1, 2, 3])) + self.assertIn('test_fn/', result_op.op.name) + + def test_nested_name(self): + + def test_fn(l): + + def body(i): + return i**2 + + l += [4] + return body(l) + + node = self.parse_and_analyze(test_fn, {}) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.test_fn(constant_op.constant([1, 2, 3])) + first_result_input_name = result_op.op.inputs[0].name + second_result_input_name = result_op.op.inputs[1].name + self.assertIn('test_fn/', first_result_input_name) + self.assertNotIn('body/', first_result_input_name) + self.assertIn('test_fn/body/', second_result_input_name) + + def test_class_name(self): + + class TestClass(object): + + def test_fn(self, l): + + def body(i): + return i**2 + + l += [4] + return body(l) + + # Note that 'TestClass' was needed in the namespace here. + node = self.parse_and_analyze( + TestClass, {'TestClass': TestClass}, owner_type=TestClass) + node = name_scopes.transform(node, self.ctx) + + with self.compiled(node, ops.name_scope) as result: + result_op = result.TestClass().test_fn(constant_op.constant([1, 2, 3])) + first_result_input_name = result_op.op.inputs[0].name + second_result_input_name = result_op.op.inputs[1].name + self.assertIn('TestClass/test_fn/', first_result_input_name) + self.assertNotIn('body/', first_result_input_name) + self.assertIn('TestClass/test_fn/body/', second_result_input_name) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/impl/BUILD b/tensorflow/contrib/py2tf/impl/BUILD index f5378917a305b538e6a943816abf70e4059c1748..90ffabbc9bf4524ec2ebf54b6dd847bd8768a486 100644 --- a/tensorflow/contrib/py2tf/impl/BUILD +++ b/tensorflow/contrib/py2tf/impl/BUILD @@ -28,6 +28,7 @@ py_library( "//tensorflow/contrib/py2tf/converters", "//tensorflow/contrib/py2tf/pyct", "//tensorflow/contrib/py2tf/pyct/static_analysis", + "//tensorflow/contrib/py2tf/utils", "@gast_archive//:gast", "@six_archive//:six", ], diff --git a/tensorflow/contrib/py2tf/impl/api.py b/tensorflow/contrib/py2tf/impl/api.py index 8ae1c701698ae9a4efbde45222ff6c3db6e92521..48100aac32844f5f10604b9c7a544c76d0b04eed 100644 --- a/tensorflow/contrib/py2tf/impl/api.py +++ b/tensorflow/contrib/py2tf/impl/api.py @@ -26,7 +26,9 @@ import six from tensorflow.contrib.py2tf.impl import config from tensorflow.contrib.py2tf.impl import conversion from tensorflow.contrib.py2tf.pyct import compiler +from tensorflow.contrib.py2tf.pyct import inspect_utils from tensorflow.contrib.py2tf.pyct import parser +from tensorflow.contrib.py2tf.utils import builtins from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_inspect @@ -110,28 +112,7 @@ def convert(recursive=False, verbose=False, arg_types=None): @wraps(f) def wrapper(*args, **kwargs): - """Wrapper that calls the compiled version of the wrapped function.""" - partial_types = () - arg_values = {} - arg_names = tf_inspect.getargspec(f)[0] - for name, arg in zip(arg_names, args): - arg_values[name] = arg - arg_class = arg.__class__ - # If arg_value_hints specifies any name, use that instead. - if name not in arg_types: - arg_types[name] = (arg_class.__name__, arg_class) - if name == 'self' and tf_inspect.isclass(arg_class): - # Annotated methods need to specify that their owner type is partial, - # otherwise other members they call will not be converted. - partial_types = (arg_class,) - wrapped = to_graph( - f, - recursive=recursive, - verbose=verbose, - arg_values=arg_values, - arg_types=arg_types, - partial_types=partial_types) - return wrapped(*args, **kwargs) + return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) # Sometimes the decorator is just desugared, making it impossible to detect. # This attribute makes detection easier. @@ -141,6 +122,78 @@ def convert(recursive=False, verbose=False, arg_types=None): return decorator +def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): + """Compiles a function call inline.""" + # TODO(mdan): This needs cleanup. + # In particular, we may want to avoid renaming functions altogether. + + if conversion.is_whitelisted_for_graph(f): + return f(*args, **kwargs) + + unknown_arg_value = object() # Sentinel for arguments of unknown value + + if tf_inspect.isbuiltin(f): + return builtins.dynamic_builtin(f, *args, **kwargs) + + if tf_inspect.isfunction(f) or tf_inspect.ismethod(f): + # Regular functions + target_entity = f + arg_map_target = f + effective_args = args + f_class = inspect_utils.getmethodclass(f) + + if f_class is not None: + partial_types = (f_class,) + else: + partial_types = () + + elif tf_inspect.isclass(f): + # Constructors + target_entity = f + arg_map_target = f.__init__ + effective_args = (unknown_arg_value,) + args + partial_types = () + + elif hasattr(f, '__call__') and hasattr(f, '__class__'): + # Callable objects + target_entity = f.__call__ + arg_map_target = f.__call__ + effective_args = (f,) + args + partial_types = (f.__class__,) + + else: + NotImplementedError('unknown callable type "%s"' % type(f)) + + arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs) + for name, arg in arg_values.items(): + if arg is unknown_arg_value: + continue + arg_class = arg.__class__ + # If arg_value_hints specifies any name, use that instead. + if name not in arg_types: + arg_types[name] = (arg_class.__name__, arg_class) + + # When called from within a decorator, this is the only indication that + # the function is a method - it appears that the decorator is applied + # before the method is bound. + if not partial_types: + if 'self' in arg_values: + if tf_inspect.isclass(arg_values['self'].__class__): + partial_types = (arg_values['self'].__class__,) + elif 'cls' in arg_values: + if tf_inspect.isclass(arg_values['cls']): + partial_types = (arg_values['cls'],) + + converted_f = to_graph( + target_entity, + recursive=recursive, + verbose=verbose, + arg_values=arg_values, + arg_types=arg_types, + partial_types=partial_types) + return converted_f(*effective_args, **kwargs) + + def to_graph(e, recursive=True, verbose=False, @@ -175,7 +228,8 @@ def to_graph(e, conversion_map = conversion.ConversionMap( recursive=recursive, nocompile_decorators=(convert, graph_ready, convert_inline), - partial_types=partial_types) + partial_types=partial_types, + api_module=tf_inspect.getmodule(to_graph)) _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) module = gast.Module([]) @@ -188,7 +242,7 @@ def to_graph(e, # The compiled code should see everything the entry function saw. # TODO(mdan): This might not work well if the call tree spans modules? if tf_inspect.isfunction(e): - compiled_node.__dict__.update(six.get_function_globals(e)) + compiled_node.__dict__.update(inspect_utils.getnamespace(e)) compiled_fn = getattr(compiled_node, name) if verbose: @@ -221,7 +275,8 @@ def to_code(e, conversion_map = conversion.ConversionMap( recursive=recursive, nocompile_decorators=(convert, graph_ready, convert_inline), - partial_types=partial_types) + partial_types=partial_types, + api_module=tf_inspect.getmodule(to_graph)) conversion.entity_to_graph(e, conversion_map, arg_values, arg_types) imports = '\n'.join(config.COMPILED_IMPORT_STATEMENTS) diff --git a/tensorflow/contrib/py2tf/impl/api_test.py b/tensorflow/contrib/py2tf/impl/api_test.py index 02cd8ed2d0ffee8ef2d31ea65902d2b493df9d64..51e99864adeba9c928b6e74eb759054ef1d1d78c 100644 --- a/tensorflow/contrib/py2tf/impl/api_test.py +++ b/tensorflow/contrib/py2tf/impl/api_test.py @@ -31,8 +31,8 @@ class ApiTest(test.TestCase): def setUp(self): config.DEFAULT_UNCOMPILED_MODULES.add((math_ops.__name__,)) config.COMPILED_IMPORT_STATEMENTS = ( - 'from tensorflow.python.ops ' - 'import control_flow_ops as tf', + 'from tensorflow.python.framework ' + 'import ops as tf', 'from tensorflow.contrib.py2tf import utils as ' 'py2tf_utils') diff --git a/tensorflow/contrib/py2tf/impl/config.py b/tensorflow/contrib/py2tf/impl/config.py index 6525806a0933dd9f0a237e278bb70b88346bea27..bdbc6663dd65ed66c55ad2d2e52428084bbea219 100644 --- a/tensorflow/contrib/py2tf/impl/config.py +++ b/tensorflow/contrib/py2tf/impl/config.py @@ -18,6 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.py2tf import utils + + PYTHON_LITERALS = { 'None': None, 'False': False, @@ -27,14 +30,21 @@ PYTHON_LITERALS = { DEFAULT_UNCOMPILED_MODULES = set(( ('tensorflow',), + (utils.__name__,), + + # All of tensorflow's subpackages. Unlike the root tf module, they don't + # have well-known names. Not refering to the module directly to avoid + # circular imports. + (utils.__name__[:-len('.contrib.py2tf.utils')],), )) NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',)) # TODO(mdan): Also allow controlling the generated names (for testability). -# TODO(mdan): Verify that these names are not hidden by generated code. -# TODO(mdan): Make sure copybara renames the reference below. COMPILED_IMPORT_STATEMENTS = ( + 'from __future__ import print_function', 'import tensorflow as tf', + 'from tensorflow.contrib.py2tf.impl import api as ' + 'py2tf_api', 'from tensorflow.contrib.py2tf import utils as ' 'py2tf_utils') diff --git a/tensorflow/contrib/py2tf/impl/conversion.py b/tensorflow/contrib/py2tf/impl/conversion.py index ff4f159975578dada45542df39f7ebbb61dd2e36..d95469ea532d5c3acc44d1e65b852f27714b8049 100644 --- a/tensorflow/contrib/py2tf/impl/conversion.py +++ b/tensorflow/contrib/py2tf/impl/conversion.py @@ -19,21 +19,23 @@ from __future__ import division from __future__ import print_function import gast -import six +from tensorflow.contrib.py2tf import utils from tensorflow.contrib.py2tf.converters import asserts -from tensorflow.contrib.py2tf.converters import break_canonicalization +from tensorflow.contrib.py2tf.converters import break_statements from tensorflow.contrib.py2tf.converters import builtin_functions from tensorflow.contrib.py2tf.converters import call_trees -from tensorflow.contrib.py2tf.converters import continue_canonicalization +from tensorflow.contrib.py2tf.converters import continue_statements from tensorflow.contrib.py2tf.converters import control_flow from tensorflow.contrib.py2tf.converters import decorators -from tensorflow.contrib.py2tf.converters import for_canonicalization +from tensorflow.contrib.py2tf.converters import for_loops from tensorflow.contrib.py2tf.converters import logical_expressions +from tensorflow.contrib.py2tf.converters import name_scopes from tensorflow.contrib.py2tf.converters import side_effect_guards from tensorflow.contrib.py2tf.impl import config from tensorflow.contrib.py2tf.impl import naming from tensorflow.contrib.py2tf.pyct import context +from tensorflow.contrib.py2tf.pyct import inspect_utils from tensorflow.contrib.py2tf.pyct import parser from tensorflow.contrib.py2tf.pyct import qual_names from tensorflow.contrib.py2tf.pyct.static_analysis import activity @@ -55,18 +57,26 @@ class ConversionMap(object): off. dependency_cache: dict[object]: ast; maps original entities to their converted AST + additional_imports: set(object); additional entities which for any reason + cannot be attached after loading and need to be explicitly imported + in the generated code name_map: dict[string]: string; maps original entities to the name of their converted counterparts + api_module: A reference to the api module. The reference needs to be passed + to avoid circular dependencies. """ # TODO(mdan): Rename to ConversionContext, and pull in additional flags. - def __init__(self, recursive, nocompile_decorators, partial_types): + def __init__(self, recursive, nocompile_decorators, partial_types, + api_module): self.recursive = recursive self.nocompile_decorators = nocompile_decorators self.partial_types = partial_types if partial_types else () self.dependency_cache = {} + self.additional_imports = set() self.name_map = {} + self.api_module = api_module def new_namer(self, namespace): return naming.Namer(namespace, self.recursive, self.name_map, @@ -87,6 +97,24 @@ class ConversionMap(object): self.dependency_cache[original_entity] = converted_ast +def is_whitelisted_for_graph(o): + """Check whether an entity is whitelisted for use in graph mode. + + Examples of whitelisted entities include all members of the tensorflow + package. + + Args: + o: A Python entity. + Returns: + Boolean + """ + m = tf_inspect.getmodule(o) + for prefix, in config.DEFAULT_UNCOMPILED_MODULES: + if m.__name__.startswith(prefix): + return True + return False + + def entity_to_graph(o, conversion_map, arg_values, arg_types): """Compile a Python entity into equivalent TensorFlow. @@ -145,7 +173,7 @@ def class_to_graph(c, conversion_map): if not members: raise ValueError('Cannot convert %s: it has no member methods.') - class_globals = None + class_namespace = None for _, m in members: node, _ = function_to_graph( m, @@ -154,10 +182,10 @@ def class_to_graph(c, conversion_map): arg_types={'self': (c.__name__, c)}, owner_type=c) # TODO(mdan): Do not assume all members have the same view of globals. - if class_globals is None: - class_globals = six.get_function_globals(m) + if class_namespace is None: + class_namespace = inspect_utils.getnamespace(m) converted_members[m] = node - namer = conversion_map.new_namer(class_globals) + namer = conversion_map.new_namer(class_namespace) class_name = namer.compiled_class_name(c.__name__, c) node = gast.ClassDef( class_name, @@ -169,22 +197,34 @@ def class_to_graph(c, conversion_map): return node, class_name +def _add_self_references(namespace, api_module): + """Self refs are only required for analysis and are not used directly.""" + # Manually add the utils namespace which may be used from generated code. + if 'py2tf_util' not in namespace: + namespace['py2tf_utils'] = utils + elif namespace['py2tf_utils'] != utils: + raise ValueError( + 'The module name "py2tf_utils" is reserved and may not be used.') + + # We also make reference to the api module for dynamic conversion, but + # to avoid circular references we don't import it here. + if 'py2tf_api' not in namespace: + namespace['py2tf_api'] = api_module + elif namespace['py2tf_api'] != api_module: + raise ValueError( + 'The module name "py2tf_api" is reserved and may not be used.') + + def function_to_graph(f, conversion_map, arg_values, arg_types, owner_type=None): """Specialization of `entity_to_graph` for callable functions.""" node, source = parser.parse_entity(f) node = node.body[0] - namespace = six.get_function_globals(f) - - # This is needed for non-global functions. - closure = six.get_function_closure(f) - if closure: - for e in closure: - if callable(e.cell_contents): - fn = e.cell_contents - namespace[fn.__name__] = fn + namespace = inspect_utils.getnamespace(f) + _add_self_references(namespace, conversion_map.api_module) namer = conversion_map.new_namer(namespace) + ctx = context.EntityContext( namer=namer, source_code=source, @@ -192,8 +232,9 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, namespace=namespace, arg_values=arg_values, arg_types=arg_types, + owner_type=owner_type, recursive=conversion_map.recursive) - node = node_to_graph(node, ctx, conversion_map.nocompile_decorators) + node, deps = node_to_graph(node, ctx, conversion_map.nocompile_decorators) # TODO(mdan): This somewhat duplicates the call rename logic in call_treest.py new_name, did_rename = namer.compiled_function_name(f.__name__, f, owner_type) @@ -204,6 +245,9 @@ def function_to_graph(f, conversion_map, arg_values, arg_types, node.name = new_name conversion_map.update_name_map(namer) + # TODO(mdan): Use this at compilation. + conversion_map.additional_imports.update(deps) + return node, new_name @@ -246,22 +290,20 @@ def node_to_graph(node, ctx, nocompile_decorators): # source. # TODO(mdan): Is it feasible to reconstruct intermediate source code? ctx.source_code = None - node = decorators.transform(node, nocompile_decorators) - node = break_canonicalization.transform(node, ctx) + node, deps = decorators.transform(node, nocompile_decorators) + node = break_statements.transform(node, ctx) node = asserts.transform(node, ctx) # Note: sequencing continue canonicalization before for loop one avoids # dealing with the extra loop increment operation that the for # canonicalization creates. - node = continue_canonicalization.transform(node, ctx) + node = continue_statements.transform(node, ctx) ctx.namespace['len'] = len node = _static_analysis_pass(node, ctx) - node = for_canonicalization.transform(node, ctx) - # for_canonicalization may insert new global references. + node = for_loops.transform(node, ctx) + # for_loops may insert new global references. node = builtin_functions.transform(node, ctx) - # builtin_functions may insert new global references. - ctx.namespace['print'] = print node = _static_analysis_pass(node, ctx) node = call_trees.transform(node, ctx, config.DEFAULT_UNCOMPILED_MODULES, @@ -272,5 +314,6 @@ def node_to_graph(node, ctx, nocompile_decorators): node = _static_analysis_pass(node, ctx) node = logical_expressions.transform(node) node = side_effect_guards.transform(node, ctx) + node = name_scopes.transform(node, ctx) - return node + return node, deps diff --git a/tensorflow/contrib/py2tf/impl/conversion_test.py b/tensorflow/contrib/py2tf/impl/conversion_test.py index 3888958f19b9fa13b759924c5188722e500e30a1..9ff256aace7a0e7ac5e7ac07e580b8bed7d8df6f 100644 --- a/tensorflow/contrib/py2tf/impl/conversion_test.py +++ b/tensorflow/contrib/py2tf/impl/conversion_test.py @@ -20,15 +20,26 @@ from __future__ import print_function import gast +from tensorflow.contrib.py2tf import utils from tensorflow.contrib.py2tf.impl import conversion +from tensorflow.python.framework import constant_op from tensorflow.python.platform import test class ConversionTest(test.TestCase): + def test_is_whitelisted_for_graph(self): + + def test_fn(): + return constant_op.constant(1) + + self.assertFalse(conversion.is_whitelisted_for_graph(test_fn)) + self.assertTrue(conversion.is_whitelisted_for_graph(utils)) + self.assertTrue(conversion.is_whitelisted_for_graph(constant_op.constant)) + def test_entity_to_graph_unsupported_types(self): with self.assertRaises(ValueError): - conversion_map = conversion.ConversionMap(True, (), ()) + conversion_map = conversion.ConversionMap(True, (), (), None) conversion.entity_to_graph('dummy', conversion_map, None, None) def test_entity_to_graph_callable(self): @@ -36,7 +47,7 @@ class ConversionTest(test.TestCase): def f(a): return a - conversion_map = conversion.ConversionMap(True, (), ()) + conversion_map = conversion.ConversionMap(True, (), (), None) ast, new_name = conversion.entity_to_graph(f, conversion_map, None, None) self.assertTrue(isinstance(ast, gast.FunctionDef), ast) self.assertEqual('tf__f', new_name) @@ -49,14 +60,17 @@ class ConversionTest(test.TestCase): def f(a): return g(a) - conversion_map = conversion.ConversionMap(True, (), ()) + conversion_map = conversion.ConversionMap(True, (), (), None) conversion.entity_to_graph(f, conversion_map, None, None) self.assertTrue(f in conversion_map.dependency_cache) self.assertTrue(g in conversion_map.dependency_cache) self.assertEqual('tf__f', conversion_map.dependency_cache[f].name) + # need the extra .body[0] in order to step past the with tf.name_scope('f') + # that is added automatically self.assertEqual( - 'tf__g', conversion_map.dependency_cache[f].body[0].value.func.id) + 'tf__g', + conversion_map.dependency_cache[f].body[0].body[0].value.func.id) self.assertEqual('tf__g', conversion_map.dependency_cache[g].name) diff --git a/tensorflow/contrib/py2tf/pyct/BUILD b/tensorflow/contrib/py2tf/pyct/BUILD index e3c0da4b10f9ffbee1b2a906b64d4762f41d97b4..edec5f7712d08247437c9e95d743e59dafffcd7b 100644 --- a/tensorflow/contrib/py2tf/pyct/BUILD +++ b/tensorflow/contrib/py2tf/pyct/BUILD @@ -24,6 +24,7 @@ py_library( "ast_util.py", "compiler.py", "context.py", + "inspect_utils.py", "parser.py", "pretty_printer.py", "qual_names.py", @@ -72,6 +73,17 @@ py_test( ], ) +py_test( + name = "inspect_utils_test", + srcs = ["inspect_utils_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pyct", + "//tensorflow/python:client_testlib", + "@gast_archive//:gast", + ], +) + py_test( name = "parser_test", srcs = ["parser_test.py"], diff --git a/tensorflow/contrib/py2tf/pyct/compiler.py b/tensorflow/contrib/py2tf/pyct/compiler.py index 0caadf18c0db2a5e557c94f4df7a3f7a7321bd60..51cf6930e8bcb3728ee55bf5d4781f01a5ef73bd 100644 --- a/tensorflow/contrib/py2tf/pyct/compiler.py +++ b/tensorflow/contrib/py2tf/pyct/compiler.py @@ -22,6 +22,7 @@ from __future__ import division from __future__ import print_function # TODO(mdan): Use six for compatibility here. +import atexit import imp import os import tempfile @@ -41,7 +42,8 @@ def ast_to_source(node, indentation): return astor.source_repr.pretty_source(generator.result).lstrip() -def ast_to_object(node, indentation=' ', source_prefix=None): +def ast_to_object( + node, indentation=' ', source_prefix=None, delete_on_exit=True): """Return the Python objects represented by given AST. Compiling the AST code this way ensures that the source code is readable by @@ -51,6 +53,8 @@ def ast_to_object(node, indentation=' ', source_prefix=None): node: The code to compile, as an AST object. indentation: The string to use for indentation. source_prefix: Optional string to print as-is into the source file. + delete_on_exit: Whether to delete the temporary file used for compilation + on exit. Returns: A module object containing the compiled source code. @@ -63,4 +67,6 @@ def ast_to_object(node, indentation=' ', source_prefix=None): f.write(source_prefix) f.write('\n') f.write(source) + if delete_on_exit: + atexit.register(lambda: os.remove(f.name)) return imp.load_source(module_name, f.name), source diff --git a/tensorflow/contrib/py2tf/pyct/context.py b/tensorflow/contrib/py2tf/pyct/context.py index fef74ebefa290369c7310af6d7e4faeef44d9aee..4fcf2a687d58af951adfc0dcf52ff7303d2b17f5 100644 --- a/tensorflow/contrib/py2tf/pyct/context.py +++ b/tensorflow/contrib/py2tf/pyct/context.py @@ -30,14 +30,16 @@ class EntityContext(object): (excluding parameters). arg_values: Dict[str->*], containing parameter values, if known. arg_types: Dict[str->*], containing parameter types, if known. + owner_type: The surrounding class type of the function, if present. """ def __init__(self, namer, source_code, source_file, namespace, arg_values, - arg_types, recursive): + arg_types, owner_type, recursive): self.namer = namer self.source_code = source_code self.source_file = source_file self.namespace = namespace self.arg_values = {} if arg_values is None else arg_values self.arg_types = {} if arg_types is None else arg_types + self.owner_type = owner_type self.recursive = recursive diff --git a/tensorflow/contrib/py2tf/pyct/inspect_utils.py b/tensorflow/contrib/py2tf/pyct/inspect_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d19c6ed75e0f0651781d6e1ed80f7be11fb8a5a4 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/inspect_utils.py @@ -0,0 +1,119 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Live entity inspection utilities. + +This module contains whatever inspect doesn't offer out of the box. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import six + +from tensorflow.python.util import tf_inspect + + +def getnamespace(f): + """Returns the complete namespace of a function. + + Namespace is defined here as the mapping of all non-local variables to values. + This includes the globals and the closure variables. Note that this captures + the entire globals collection of the function, and may contain extra symbols + that it does not actually use. + + Args: + f: User defined function. + Returns: + A dict mapping symbol names to values. + """ + namespace = dict(six.get_function_globals(f)) + closure = six.get_function_closure(f) + freevars = six.get_function_code(f).co_freevars + if freevars and closure: + for name, cell in zip(freevars, closure): + namespace[name] = cell.cell_contents + return namespace + + +def getmethodclass(m): + """Resolves a function's owner, e.g. a method's class. + + Note that this returns the object that the function was retrieved from, not + necessarily the class where it was defined. + + This function relies on Python stack frame support in the interpreter, and + has the same limitations that inspect.currentframe. + + Limitations. This function will only work correctly if the owned class is + visible in the caller's global or local variables. + + Args: + m: A user defined function + + Returns: + The class that this function was retrieved from, or None if the function + is not an object or class method, or the class that owns the object or + method is not visible to m. + + Raises: + ValueError: if the class could not be resolved for any unexpected reason. + """ + + # Instance method and class methods: should be bound to a non-null "self". + # If self is a class, then it's a class method. + if hasattr(m, '__self__'): + if m.__self__: + if tf_inspect.isclass(m.__self__): + return m.__self__ + return type(m.__self__) + + # Class, static and unbound methods: search all defined classes in any + # namespace. This is inefficient but more robust method. + owners = [] + caller_frame = tf_inspect.currentframe().f_back + try: + # TODO(mdan): This doesn't consider cell variables. + # TODO(mdan): This won't work if the owner is hidden inside a container. + # Cell variables may be pulled using co_freevars and the closure. + for v in itertools.chain(caller_frame.f_locals.values(), + caller_frame.f_globals.values()): + if hasattr(v, m.__name__): + candidate = getattr(v, m.__name__) + # Py2 methods may be bound or unbound, extract im_func to get the + # underlying function. + if hasattr(candidate, 'im_func'): + candidate = candidate.im_func + if hasattr(m, 'im_func'): + m = m.im_func + if candidate is m: + owners.append(v) + finally: + del caller_frame + + if owners: + if len(owners) == 1: + return owners[0] + + # If multiple owners are found, and are not subclasses, raise an error. + owner_types = tuple(o if tf_inspect.isclass(o) else type(o) for o in owners) + for o in owner_types: + if tf_inspect.isclass(o) and issubclass(o, tuple(owner_types)): + return o + raise ValueError('Found too many owners of %s: %s' % (m, owners)) + + return None diff --git a/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py b/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5528ac851f74bd7b7dacdbe7b930945afa8c9783 --- /dev/null +++ b/tensorflow/contrib/py2tf/pyct/inspect_utils_test.py @@ -0,0 +1,230 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for unspect_utils module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from functools import wraps + +import six + +from tensorflow.contrib.py2tf.pyct import inspect_utils +from tensorflow.python.platform import test + + +def decorator(f): + return f + + +def function_decorator(): + def dec(f): + return f + return dec + + +def wrapping_decorator(): + def dec(f): + def replacement(*_): + return None + + @wraps(f) + def wrapper(*args, **kwargs): + return replacement(*args, **kwargs) + return wrapper + return dec + + +class TestClass(object): + + def member_function(self): + pass + + @decorator + def decorated_member(self): + pass + + @function_decorator() + def fn_decorated_member(self): + pass + + @wrapping_decorator() + def wrap_decorated_member(self): + pass + + @staticmethod + def static_method(): + pass + + @classmethod + def class_method(cls): + pass + + +def free_function(): + pass + + +def factory(): + return free_function + + +def free_factory(): + def local_function(): + pass + return local_function + + +class InspectUtilsTest(test.TestCase): + + def test_getnamespace_globals(self): + ns = inspect_utils.getnamespace(factory) + self.assertEqual(ns['free_function'], free_function) + + def test_getnamespace_hermetic(self): + + # Intentionally hiding the global function to make sure we don't overwrite + # it in the global namespace. + free_function = object() # pylint:disable=redefined-outer-name + + def test_fn(): + return free_function + + ns = inspect_utils.getnamespace(test_fn) + globs = six.get_function_globals(test_fn) + self.assertTrue(ns['free_function'] is free_function) + self.assertFalse(globs['free_function'] is free_function) + + def test_getnamespace_locals(self): + + def called_fn(): + return 0 + + closed_over_list = [] + closed_over_primitive = 1 + + def local_fn(): + closed_over_list.append(1) + local_var = 1 + return called_fn() + local_var + closed_over_primitive + + ns = inspect_utils.getnamespace(local_fn) + self.assertEqual(ns['called_fn'], called_fn) + self.assertEqual(ns['closed_over_list'], closed_over_list) + self.assertEqual(ns['closed_over_primitive'], closed_over_primitive) + self.assertTrue('local_var' not in ns) + + def test_getmethodclass(self): + + self.assertEqual( + inspect_utils.getmethodclass(free_function), None) + self.assertEqual( + inspect_utils.getmethodclass(free_factory()), None) + + self.assertEqual( + inspect_utils.getmethodclass(TestClass.member_function), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.fn_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.wrap_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.static_method), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(TestClass.class_method), + TestClass) + + test_obj = TestClass() + self.assertEqual( + inspect_utils.getmethodclass(test_obj.member_function), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.fn_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.wrap_decorated_member), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.static_method), + TestClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.class_method), + TestClass) + + def test_getmethodclass_locals(self): + + def local_function(): + pass + + class LocalClass(object): + + def member_function(self): + pass + + @decorator + def decorated_member(self): + pass + + @function_decorator() + def fn_decorated_member(self): + pass + + @wrapping_decorator() + def wrap_decorated_member(self): + pass + + self.assertEqual( + inspect_utils.getmethodclass(local_function), None) + + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.member_function), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.fn_decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(LocalClass.wrap_decorated_member), + LocalClass) + + test_obj = LocalClass() + self.assertEqual( + inspect_utils.getmethodclass(test_obj.member_function), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.fn_decorated_member), + LocalClass) + self.assertEqual( + inspect_utils.getmethodclass(test_obj.wrap_decorated_member), + LocalClass) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py b/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py index 1c93e1603113d48176af7a97a0f37321e6f67586..02ea6fdeaf78152b6bc48983f79b36f43d4f665d 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/activity.py @@ -24,6 +24,7 @@ import gast from tensorflow.contrib.py2tf.pyct import anno from tensorflow.contrib.py2tf.pyct import transformer +from tensorflow.contrib.py2tf.pyct.qual_names import QN from tensorflow.contrib.py2tf.pyct.static_analysis.annos import NodeAnno # TODO(mdan): Add support for PY3 (e.g. Param vs arg). @@ -237,6 +238,18 @@ class ActivityAnalizer(transformer.Base): self.scope.merge_from(after_child) return parent + def visit_FunctionDef(self, node): + if self.scope: + qn = QN(node.name) + self.scope.mark_write(qn) + current_scope = self.scope + fndef_scope = Scope(current_scope, isolated=True) + self.scope = fndef_scope + self.generic_visit(node) + anno.setanno(node, NodeAnno.BODY_SCOPE, fndef_scope) + self.scope = current_scope + return node + def visit_If(self, node): self.visit(node.test) node = self._process_parallel_blocks(node, diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py index e1eb954a5efef4d6a00ac492e7c85394d54e28c9..69f5f4fc582f159e46c8b8929a90ca95b724794d 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/activity_test.py @@ -108,6 +108,7 @@ class ActivityAnalizerTest(test.TestCase): namespace={}, arg_values=None, arg_types=None, + owner_type=None, recursive=True) node = qual_names.resolve(node) node = activity.resolve(node, ctx) @@ -239,6 +240,33 @@ class ActivityAnalizerTest(test.TestCase): anno.getanno(if_node, NodeAnno.ORELSE_SCOPE).parent, ('x', 'z', 'u'), ('x', 'y', 'z', 'u'), ('x', 'y', 'z', 'u')) + def test_functiondef(self): + + def test_fn(a): + + def f(x): + y = x * x + return y + + b = a + for i in a: + c = b + b -= f(i) + return b, c + + node = self._parse_and_analyze(test_fn) + fndef_node = node.body[0].body[0] + + self.assertScopeIs( + anno.getanno(fndef_node, + NodeAnno.BODY_SCOPE).parent, ('b', 'i', 'f', 'c', 'a'), + ('f', 'b', 'c', 'i'), ('f', 'a', 'b', 'c', 'i')) + self.assertScopeIs( + anno.getanno(fndef_node, NodeAnno.BODY_SCOPE), ('x', 'y'), ('y',), ( + 'x', + 'y', + )) + def test_call_with_composite_names(self): def foo(*_): diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py index 9c0a9a9e74eccb3d22840032e8f0c2b81e051e7e..0388be5d252389f2f3516c8b27828905d6475589 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values.py @@ -86,6 +86,7 @@ class LiveValueResolver(transformer.Base): if not hasattr(parent_object, node.attr): raise AttributeError('%s has no attribute %s' % (parent_object, node.attr)) + anno.setanno(node, 'parent_type', type(parent_object)) anno.setanno(node, 'live_val', getattr(parent_object, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'fqn') + (node.attr,)) # TODO(mdan): Investigate the role built-in annotations can play here. @@ -96,6 +97,7 @@ class LiveValueResolver(transformer.Base): # This would not hold for dynamic members like function attributes. # For the dynamic case, we simply leave the node without an annotation, # and let downstream consumers figure out what to do. + anno.setanno(node, 'parent_type', parent_type) anno.setanno(node, 'live_val', getattr(parent_type, node.attr)) anno.setanno(node, 'fqn', anno.getanno(node.value, 'type_fqn') + (node.attr,)) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py index 9f64689401e3594a77fbdd7b6f02880bd6e90492..c133a455b3dd328689102634c6076f366212ac25 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/live_values_test.py @@ -46,6 +46,7 @@ class LiveValuesResolverTest(test.TestCase): namespace=namespace, arg_values=None, arg_types=arg_types, + owner_type=None, recursive=True) node = qual_names.resolve(node) node = activity.resolve(node, ctx) @@ -102,6 +103,7 @@ class LiveValuesResolverTest(test.TestCase): arg_types={'self': (TestClass.__name__, TestClass)}) func_node = node.body[0].body[0].value.func self.assertEquals(TestClass.member, anno.getanno(func_node, 'live_val')) + self.assertEquals(TestClass, anno.getanno(func_node, 'parent_type')) self.assertEquals(('TestClass', 'member'), anno.getanno(func_node, 'fqn')) diff --git a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py index 3659f949db9910534870d8dd9e42fd4ee8297253..a3e78202c80e45552c038a6a1da763eb30aff52f 100644 --- a/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py +++ b/tensorflow/contrib/py2tf/pyct/static_analysis/type_info_test.py @@ -65,6 +65,7 @@ class TypeInfoResolverTest(test.TestCase): namespace=namespace, arg_values=None, arg_types=arg_types, + owner_type=None, recursive=True) node = qual_names.resolve(node) node = activity.resolve(node, ctx) diff --git a/tensorflow/contrib/py2tf/pyct/templates.py b/tensorflow/contrib/py2tf/pyct/templates.py index c40e4d0fb783191705a412ab2728daabb61eda0f..6ee6c0c5ceb70d87779ee313670135cadc5214b5 100644 --- a/tensorflow/contrib/py2tf/pyct/templates.py +++ b/tensorflow/contrib/py2tf/pyct/templates.py @@ -68,6 +68,10 @@ class ReplaceTransformer(gast.NodeTransformer): if isinstance(node, gast.Attribute): self._set_inner_child_context(node.value, ctx) node.ctx = gast.Load() + elif isinstance(node, gast.Tuple): + for e in node.elts: + self._set_inner_child_context(e, ctx) + node.ctx = ctx elif isinstance(node, gast.Name): node.ctx = ctx elif isinstance(node, (gast.Str, gast.Num)): diff --git a/tensorflow/contrib/py2tf/utils/BUILD b/tensorflow/contrib/py2tf/utils/BUILD index c2987fcace91e82511c48151fb3eb089f24ea35c..2086a9ef6077145ad07f342eea3491f862763158 100644 --- a/tensorflow/contrib/py2tf/utils/BUILD +++ b/tensorflow/contrib/py2tf/utils/BUILD @@ -20,14 +20,18 @@ py_library( name = "utils", srcs = [ "__init__.py", + "builtins.py", "context_managers.py", "misc.py", "multiple_dispatch.py", + "py_func.py", + "tensor_list.py", "type_check.py", ], srcs_version = "PY2AND3", visibility = ["//tensorflow:__subpackages__"], deps = [ + "//tensorflow/python:script_ops", "@six_archive//:six", ], ) @@ -62,6 +66,16 @@ py_test( ], ) +py_test( + name = "py_func_test", + srcs = ["py_func_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + ], +) + py_test( name = "type_check_test", srcs = ["type_check_test.py"], @@ -71,3 +85,14 @@ py_test( "//tensorflow/python:client_testlib", ], ) + +py_test( + name = "tensor_list_test", + srcs = ["tensor_list_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:client_testlib", + "//tensorflow/python:list_ops", + ], +) diff --git a/tensorflow/contrib/py2tf/utils/__init__.py b/tensorflow/contrib/py2tf/utils/__init__.py index 1cbb0e002997ca5df16c28102e1f1092ebfac050..19bf2272bcbcd230575d6654678b381b1d132518 100644 --- a/tensorflow/contrib/py2tf/utils/__init__.py +++ b/tensorflow/contrib/py2tf/utils/__init__.py @@ -18,8 +18,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.py2tf.utils.builtins import dynamic_builtin +from tensorflow.contrib.py2tf.utils.builtins import dynamic_print from tensorflow.contrib.py2tf.utils.context_managers import control_dependency_on_returns from tensorflow.contrib.py2tf.utils.misc import alias_tensors from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_cond from tensorflow.contrib.py2tf.utils.multiple_dispatch import run_while +from tensorflow.contrib.py2tf.utils.py_func import wrap_py_func from tensorflow.contrib.py2tf.utils.type_check import is_tensor diff --git a/tensorflow/contrib/py2tf/utils/builtins.py b/tensorflow/contrib/py2tf/utils/builtins.py new file mode 100644 index 0000000000000000000000000000000000000000..0a50b80b60101afaa9aa0f445079727e9708ac35 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/builtins.py @@ -0,0 +1,73 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Builtin conversion utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils import py_func +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.util import tf_inspect + + +def dynamic_builtin(f, *args, **kwargs): + """Converts a builtin function call inline.""" + if not tf_inspect.isbuiltin(f): + return f(*args, **kwargs) + + if f is len: + return dynamic_len(*args, **kwargs) + + raise NotImplementedError('The "%s" builtin is not yet supported.' % f) + + +def dynamic_len(list_or_tensor): + """Implementation of len using dynamic dispatch.""" + if tensor_util.is_tensor(list_or_tensor): + shape = list_or_tensor.shape + if not shape: + raise ValueError( + 'len requires non-zero rank for tensor "%s"' % list_or_tensor) + return array_ops.shape(list_or_tensor)[0] + + return len(list_or_tensor) + + +def is_tf_print_compatible(value): + # TODO(mdan): Enable once we can reliably test this. + # This is currently disabled because we can't capture the output of + # op kernels from Python. + del value + return False + + +def dynamic_print(*values): + """Implementartion of print using dynamic dispatch. + + The function attempts to use tf.Print if all the values are compatible. + Otherwise, it will fall back to py_func. + + Args: + *values: values to print + Returns: + A dummy value indicating the print completed. If tf. + """ + + if all(map(is_tf_print_compatible, values)): + return logging_ops.Print(1, values) + return py_func.wrap_py_func(print, None, values, use_dummy_return=True) diff --git a/tensorflow/contrib/py2tf/utils/builtins_test.py b/tensorflow/contrib/py2tf/utils/builtins_test.py new file mode 100644 index 0000000000000000000000000000000000000000..19a72c63ecc873c52abde18e481221fc782ad490 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/builtins_test.py @@ -0,0 +1,78 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for builtins module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import six + +from tensorflow.contrib.py2tf.utils import builtins +from tensorflow.python.framework import constant_op +from tensorflow.python.platform import test + + +class BuiltinsTest(test.TestCase): + + def test_dynamic_len_tf_scalar(self): + a = constant_op.constant(1) + + with self.assertRaises(ValueError): + with self.test_session() as sess: + sess.run(builtins.dynamic_builtin(len, a)) + + def test_dynamic_len_tf_array(self): + a = constant_op.constant([1, 2, 3]) + + with self.test_session() as sess: + self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) + + def test_dynamic_len_tf_matrix(self): + a = constant_op.constant([[1, 2], [3, 4]]) + + with self.test_session() as sess: + self.assertEqual(2, sess.run(builtins.dynamic_builtin(len, a))) + + def test_dynamic_len_py_list(self): + a = [3] * 5 + + self.assertEqual(5, builtins.dynamic_builtin(len, a)) + + def test_dynamic_print_tf(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(builtins.dynamic_print('test message', 1)) + self.assertEqual(out_capturer.getvalue(), 'test message 1\n') + finally: + sys.stdout = sys.__stdout__ + + def test_dynamic_print_complex(self): + try: + out_capturer = six.StringIO() + sys.stdout = out_capturer + with self.test_session() as sess: + sess.run(builtins.dynamic_print('test message', [1, 2])) + self.assertEqual(out_capturer.getvalue(), 'test message [1, 2]\n') + finally: + sys.stdout = sys.__stdout__ + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/misc_test.py b/tensorflow/contrib/py2tf/utils/misc_test.py index bfcb304c838df69e9e3961907362c7939c065117..8aedd4cd64798660cc07364c45487399986c9be6 100644 --- a/tensorflow/contrib/py2tf/utils/misc_test.py +++ b/tensorflow/contrib/py2tf/utils/misc_test.py @@ -18,29 +18,29 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.py2tf.utils import misc -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import variables +from tensorflow.contrib.py2tf.utils.misc import alias_tensors +from tensorflow.python.framework.constant_op import constant +from tensorflow.python.ops.variables import Variable from tensorflow.python.platform import test -class ContextManagersTest(test.TestCase): +class MiscTest(test.TestCase): def test_alias_single_tensor(self): - a = constant_op.constant(1) + a = constant(1) - new_a = misc.alias_tensors(a) + new_a = alias_tensors(a) self.assertFalse(new_a is a) with self.test_session() as sess: self.assertEqual(1, sess.run(new_a)) def test_alias_tensors(self): - a = constant_op.constant(1) - v = variables.Variable(2) + a = constant(1) + v = Variable(2) s = 'a' l = [1, 2, 3] - new_a, new_v, new_s, new_l = misc.alias_tensors(a, v, s, l) + new_a, new_v, new_s, new_l = alias_tensors(a, v, s, l) self.assertFalse(new_a is a) self.assertTrue(new_v is v) diff --git a/tensorflow/contrib/py2tf/utils/py_func.py b/tensorflow/contrib/py2tf/utils/py_func.py new file mode 100644 index 0000000000000000000000000000000000000000..838872d092a3ab07e965180eff4fec7ff6c4ccf9 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/py_func.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Pyfunc creation utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import script_ops + + +def wrap_py_func(f, return_dtypes, arguments, use_dummy_return=False): + """Helper that wraps a callable to py_func. + + The helper passes tensor arguments through the py_func interface. Non-tensor + arguments are allowed, and will be passed to f directly. Note that non-tensor + arguments are captured by f will not update every time the wrapper is + called (this is consistent with its argument list, which only includes + the tensor arguments). In general, it's safest not to reuse this wrapper. + + Args: + f: Callable + return_dtypes: DType, tuple, list or None, the data type for each of f's + return value. None if f has no return values or use_dummy_return is + True. + arguments: Arguments for f + use_dummy_return: If True, the function will return a dummy value of 1 + and discard its actual return value. + Returns: + The return values of f converted to tensor. + Raises: + ValueError: if the arguments are incorrect. + """ + + if return_dtypes and use_dummy_return: + raise ValueError('if use_dummy_return is True, return_dtypes must be empty') + + n = len(arguments) + arg_is_tensor = tuple(map(tensor_util.is_tensor, arguments)) + index_in_tensor_list = [0] * n + i = 0 + for j in range(n): + index_in_tensor_list[j] = i + if arg_is_tensor[j]: + i += 1 + + def f_wrapper(*tensor_args): + f_args = tuple(tensor_args[index_in_tensor_list[i]] + if arg_is_tensor[i] else arguments[i] for i in range(n)) + retval = f(*f_args) + return 1 if use_dummy_return else retval + + return script_ops.py_func( + f_wrapper, tuple(arguments[i] for i in range(n) if arg_is_tensor[i]), + dtypes.int64 if use_dummy_return else return_dtypes) diff --git a/tensorflow/contrib/py2tf/utils/py_func_test.py b/tensorflow/contrib/py2tf/utils/py_func_test.py new file mode 100644 index 0000000000000000000000000000000000000000..776b5309c6f027bb2008aa83d48e4155e817ed97 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/py_func_test.py @@ -0,0 +1,91 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for wrap_py_func module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils import py_func +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.platform import test + + +class PyFuncTest(test.TestCase): + + def test_wrap_py_func_simple(self): + + def test_fn(a, b, c): + return a + b + c + + with self.test_session() as sess: + tensor_1 = constant_op.constant(1) + self.assertEqual(3, + sess.run( + py_func.wrap_py_func(test_fn, dtypes.int64, + (1, tensor_1, 1)))) + self.assertEqual(3, + sess.run( + py_func.wrap_py_func(test_fn, dtypes.int64, + (1, 1, 1)))) + self.assertEqual(3, + sess.run( + py_func.wrap_py_func(test_fn, dtypes.int64, + (tensor_1, 1, tensor_1)))) + + def test_wrap_py_func_complex_args(self): + + class TestClass(object): + + def __init__(self): + self.foo = 5 + + def test_fn(a, b): + return a * b.foo + + with self.test_session() as sess: + self.assertEqual(35, + sess.run( + py_func.wrap_py_func(test_fn, dtypes.int64, + (7, TestClass())))) + self.assertEqual( + 35, + sess.run( + py_func.wrap_py_func(test_fn, dtypes.int64, + (constant_op.constant(7), TestClass())))) + + def test_wrap_py_func_dummy_return(self): + + side_counter = [0] + + def test_fn(_): + side_counter[0] += 1 + + with self.test_session() as sess: + self.assertEqual(1, + sess.run( + py_func.wrap_py_func(test_fn, None, (5,), True))) + self.assertEqual([1], side_counter) + self.assertEqual(1, + sess.run( + py_func.wrap_py_func(test_fn, None, + (constant_op.constant(5),), + True))) + self.assertEqual([2], side_counter) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/py2tf/utils/tensor_list.py b/tensorflow/contrib/py2tf/utils/tensor_list.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ff49e2a0eff384f10903e12212ab929e267804 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/tensor_list.py @@ -0,0 +1,49 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A typed list in Python.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.ops import list_ops + + +class TensorList(object): + """Tensor list wrapper API-compatible with Python built-in list.""" + + def __init__(self, shape, dtype): + self.dtype = dtype + self.shape = shape + self.clear() + + def append(self, value): + self.list_ = list_ops.tensor_list_push_back(self.list_, value) + + def pop(self): + self.list_, value = list_ops.tensor_list_pop_back(self.list_, self.dtype) + return value + + def clear(self): + self.list_ = list_ops.empty_tensor_list(self.shape, self.dtype) + + def count(self): + return list_ops.tensor_list_length(self.list_) + + def __getitem__(self, key): + return list_ops.tensor_list_get_item(self.list_, key, self.dtype) + + def __setitem__(self, key, value): + self.list_ = list_ops.tensor_list_set_item(self.list_, key, value) diff --git a/tensorflow/contrib/py2tf/utils/tensor_list_test.py b/tensorflow/contrib/py2tf/utils/tensor_list_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b5e554a162674e08da21785dcbe193c54647f128 --- /dev/null +++ b/tensorflow/contrib/py2tf/utils/tensor_list_test.py @@ -0,0 +1,89 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for PyFlow list.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.py2tf.utils import tensor_list as tl +from tensorflow.python.client.session import Session +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.framework.constant_op import constant +from tensorflow.python.platform import test + + +class TensorListTest(test.TestCase): + + def test_list_append_python(self): + with context.eager_mode(): + a = constant(3.0) + l = tl.TensorList(a.shape, a.dtype) + l.append(a) + self.assertEqual(l.count().numpy(), 1) + l.append(a) + self.assertEqual(l.count().numpy(), 2) + _ = l.pop() + self.assertEqual(l.count().numpy(), 1) + a2 = l.pop() + self.assertEqual(l.count().numpy(), 0) + self.assertEqual(a.numpy(), a2.numpy()) + + def test_list_index_python(self): + with context.eager_mode(): + a = constant(3.0) + b = constant(2.0) + l = tl.TensorList(a.shape, a.dtype) + l.append(a) + self.assertEqual(l[0].numpy(), a.numpy()) + l[0] = ops.convert_to_tensor(b) + self.assertEqual(l[0].numpy(), b.numpy()) + + def test_list_append_tf(self): + a = constant(3.0) + l = tl.TensorList(a.shape, a.dtype) + l.append(a) + c1 = l.count() + l.append(a) + c2 = l.count() + _ = l.pop() + c3 = l.count() + a2 = l.pop() + c4 = l.count() + with Session() as sess: + c1, c2, c3, c4, a, a2 = sess.run([c1, c2, c3, c4, a, a2]) + self.assertEqual(c1, 1) + self.assertEqual(c2, 2) + self.assertEqual(c3, 1) + self.assertEqual(c4, 0) + self.assertEqual(a, a2) + + def test_list_index_tf(self): + a = constant(3.0) + b = constant(2.0) + l = tl.TensorList(a.shape, a.dtype) + l.append(a) + l0 = l[0] + l[0] = b + l1 = l[0] + with self.test_session() as sess: + l0, l1, a, b = sess.run([l0, l1, a, b]) + self.assertEqual(l0, a) + self.assertEqual(l1, b) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/quantization/README.md b/tensorflow/contrib/quantization/README.md new file mode 100644 index 0000000000000000000000000000000000000000..359950aaf3d89c1f3e8fda21cbd27fb89217d918 --- /dev/null +++ b/tensorflow/contrib/quantization/README.md @@ -0,0 +1,7 @@ +The contrib/quantization package exposes a few TensorFlow quantization operations. + +If you are looking for quantized training rewrites that allow for training +quantized models that work with +[TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/), you should look at +the [contrib/quantize](https://www.tensorflow.org/api_docs/python/tf/contrib/quantize) +package. diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md index 40541729da5fd9d0ae75579e11f20999337de124..8b0e7bb68f5a11f5d1942f7cf048e96768da259e 100644 --- a/tensorflow/contrib/quantize/README.md +++ b/tensorflow/contrib/quantize/README.md @@ -1,9 +1,10 @@ +# Quantized Training Rewrites + tf.contrib.quantize provides tools for transforming graphs to include ops to model quantization of weights, biases and activations during both training and inference. This is done using the [fake quantization op] -(https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization), -which is described below: +(https://www.tensorflow.org/versions/r0.12/api_docs/python/array_ops/fake_quantization). Recent literature has shown that fixed point networks provide comparable performance to floating point networks [1]. This is achieved by modeling the @@ -14,56 +15,52 @@ updated at high precision as this is needed to ensure sufficient precision in accumulating tiny adjustments to the parameters. However, for the forward pass, the parameters and activations are quantized to the desired lower precision. -![drawing](g3doc/drawings/Fake_Quantization.jpg) - -###Forward pass - - - - -\begin{equation*} -f_Q(x) = \Delta\text{ }round\left(\frac{sat\left(x\right)-x_{min}}{\Delta}\right) -\end{equation*} - - -where - -$$ -\begin{equation*} -sat(x) = -\left\{ - \begin{array}{ll} - x_{min} & \mbox{if } x \le x_{min} \\ - x & \mbox{if } x_{min} \leq x \leq x_{max} \\ - x_{max} & \mbox{if } x_{max} \le x - \end{array} -\right. -\end{equation*} -$$ - - -where $$\Delta$$ is the Quantizer Step size, given by -$$\Delta =\frac{x_{max} - x_{min} }{255} $$ and $$x_{min} $$ and $$x_{max}$$ are -the minimum and maximum values of the variable under consideration. Note that -the rounding performed is deterministic and corresponds to asymmetric rounding, -which is supported in almost all hardware platforms. - -###Backward pass -For the backward pass, we model the quantizer as a piecewise linear block, with -derivatives that are non-zero only in the linear region. - - - -\begin{equation*} -\frac{df_Q(x)}{dx}=1, x_{min} \leq x \leq x_{max},\text{ 0 elsewhere } -\end{equation*} - -Therefore, the backward pass through the quantizer reduces to passing through -the gradients as long as the inputs to the quantizer are in the linear region. -Otherwise, the gradients are set to zero. - -Note that the quantizer is fully specified by the min and max values of the -variables being quantized. +## How to use the Rewrites + +tf.contrib.quantize provides two rewrites, one to train for quantization and +one to create a [TensorFlow Lite](https://www.tensorflow.org/mobile/tflite/) +compatible eval graph. + +``` +# Build forward pass of model. +… +loss = tf.losses.get_total_loss() + +# Call the training rewrite which rewrites the graph in-place with FakeQuantization nodes +# and folds batchnorm for training. +# It is often needed to finetune a floating point model for quantization with this training tool. +# When training from scratch, quant_delay can be used to activate quantization after +# training to convergence with the float graph, effectively finetuning the model. +tf.contrib.quantize.create_training_graph(quant_delay=2000000) + +# Call backward pass optimizer as usual. +optimizer = tf.train.GradientDescentOptimizer(learning_rate) +optimizer.minimize(loss) +``` + +Additionally, the rewritten eval graph is non-trivially different from the +training graph due the effects of quantization on batch normalization. Thus, +we offer a separate rewrite for the eval_graph. + +``` +# Build eval model +… +logits = tf.nn.softmax_cross_entropy_with_logits(...) + +# Call the eval rewrite which rewrites the graph in-place with FakeQuantization nodes +# and fold batchnorm for eval. +tf.contrib.quantize.create_eval_graph() + +# Save the checkpoint and eval graph proto to disk for freezing and providing to TFLite. +with open(eval_graph_file, ‘w’) as f: + f.write(str(g.as_graph_def())) +saver = tf.train.Saver() +saver.save(sess, checkpoint_name) +``` + +These rewrites are an active area of research and experimentation, so the +rewrites and quantized training will likely not work across all models, though +we hope to work towards generalizing these techniques. [1] P.Gysel, "HARDWARE-ORIENTED APPROXIMATION OF CONVOLUTIONAL diff --git a/tensorflow/contrib/quantize/g3doc/drawings/Fake_Quantization.jpg b/tensorflow/contrib/quantize/g3doc/drawings/Fake_Quantization.jpg deleted file mode 100644 index fdc7ae40cec757cc0a93d50eca6c8698a4697d07..0000000000000000000000000000000000000000 Binary files a/tensorflow/contrib/quantize/g3doc/drawings/Fake_Quantization.jpg and /dev/null differ diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index 36a848d2a8775cf63bd80d4b46eeff7eca97727e..75d9eb0e58d96e4bb2946684febd250e2e1a6b4a 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -34,7 +34,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.util import compat -def FoldBatchNorms(graph, freeze_batch_norm_delay=None, is_training=True): +def FoldBatchNorms(graph, is_training, freeze_batch_norm_delay=None): """Finds batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise @@ -42,24 +42,22 @@ def FoldBatchNorms(graph, freeze_batch_norm_delay=None, is_training=True): Args: graph: Graph to walk and modify. + is_training: Bool, true if training. freeze_batch_norm_delay: How many steps to wait before freezing moving mean and variance and using them for batch normalization. This value is used only when is_training is True. - is_training: Bool, true if training. Raises: ValueError: When batch norm folding fails. """ _FoldFusedBatchNorms( - graph, - freeze_batch_norm_delay=freeze_batch_norm_delay, - is_training=is_training) + graph, is_training, freeze_batch_norm_delay=freeze_batch_norm_delay) _FoldUnfusedBatchNorms( graph, - freeze_batch_norm_delay=freeze_batch_norm_delay, - is_training=is_training) + is_training=is_training, + freeze_batch_norm_delay=freeze_batch_norm_delay) -def _FoldFusedBatchNorms(graph, freeze_batch_norm_delay, is_training): +def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): """Finds fused batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise @@ -67,9 +65,9 @@ def _FoldFusedBatchNorms(graph, freeze_batch_norm_delay, is_training): Args: graph: Graph to walk and modify. + is_training: Bool, true if training. freeze_batch_norm_delay: How many steps to wait before freezing moving mean and variance and using them for batch normalization. - is_training: Bool, true if training. Raises: ValueError: When batch norm folding fails. @@ -416,7 +414,7 @@ def _FoldFusedBatchNormGrad(op, unused_grad_y, grad_mean, grad_var, unused_1, return (dmean_dx + dvar_dx), None, None, None, None -def _FoldUnfusedBatchNorms(graph, freeze_batch_norm_delay, is_training): +def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): """Finds unfused batch norm layers and folds them into preceding layers. Folding only affects the following layers: Conv2D, fully connected, depthwise @@ -424,9 +422,9 @@ def _FoldUnfusedBatchNorms(graph, freeze_batch_norm_delay, is_training): Args: graph: Graph to walk and modify. + is_training: Bool, True if training. freeze_batch_norm_delay: How many steps to wait before freezing moving mean and variance and using them for batch normalization. - is_training: Bool, True if training Raises: ValueError: When batch norm folding fails. diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index e44b91f0d0d336994888ab239034c5cbb62ddeba..5fd806d195dce671d079386ea4b6c89042e26cf6 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -40,16 +40,17 @@ _WEIGHT_TYPES = {'Variable', 'VariableV2'} def Quantize(graph, + is_training, weight_bits=8, activation_bits=8, ema_decay=0.999, quant_delay=None, - vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - is_training=True): + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES): """Updates graph with quantization operations. Args: graph: Graph to modify. + is_training: Whether quantizing training graph or eval graph. weight_bits: Number of bits to use for quantizing weights. activation_bits: Number of bits to use for quantizing activations. ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update @@ -60,7 +61,6 @@ def Quantize(graph, training. vars_collection: (Optional) Collection where to store the variables for quantization interval ends. - is_training: (Optional) Whether quantizing training graph or eval graph. Raises: ValueError: When quantization fails. """ @@ -70,15 +70,15 @@ def Quantize(graph, context = _GetContextFromOp(layer_match.layer_op) _InsertQuantOp( context, + 'weights_quant', layer_match.weight_tensor.op, [layer_match.layer_op], - name='weights_quant', + is_training, moving_avg=False, - bits=weight_bits, ema_decay=ema_decay, quant_delay=quant_delay, - is_training=is_training, narrow_range=True, - vars_collection=vars_collection) + vars_collection=vars_collection, + bits=weight_bits) # Quantize the activations. consumer_ops = input_to_ops_map.ConsumerOperations( @@ -88,23 +88,25 @@ def Quantize(graph, add_context = re.search(r'^(.*)/([^/]+)', context).group(1) _InsertQuantOp( add_context, + 'act_quant', layer_match.activation_op, consumer_ops, - name='act_quant', + is_training, moving_avg=True, - init_min=0.0, ema_decay=ema_decay, quant_delay=quant_delay, + vars_collection=vars_collection, bits=activation_bits, - vars_collection=vars_collection) + init_min=0.0) # Quantize the inputs and output to the bypass (if it exists). The input to # the bypass is the bias add, and the output is the activation. if layer_match.bypass_op is not None: _InsertQuantOp( context, + 'conv_quant', layer_match.bias_add_op, [layer_match.bypass_op], - name='conv_quant', + is_training, moving_avg=True, ema_decay=ema_decay, quant_delay=quant_delay, @@ -112,10 +114,14 @@ def Quantize(graph, bits=activation_bits) _InsertQuantOp( add_context, + 'add_quant', layer_match.bypass_op, input_to_ops_map.ConsumerOperations(layer_match.bypass_op), - name='add_quant', + is_training, moving_avg=True, + ema_decay=ema_decay, + quant_delay=quant_delay, + vars_collection=vars_collection, bits=activation_bits) @@ -201,6 +207,18 @@ def _FindLayersToQuantize(graph): yield _LayerMatch(layer_op, weight_tensor, activation_op, bypass_op, bias_add_op) + # Match the final layer, where there will not be an activation and instead + # the output of the final BiasAdd must be quantized, so we treat it as the + # 'activation_op' in the _LayerMatch. + # TODO(suharshs): Figure out how to quantize this final layer across many + # models. + final_layer_matcher = graph_matcher.GraphMatcher(bias_add_pattern) + for match_result in final_layer_matcher.match_graph(graph): + layer_op = match_result.get_op(layer_pattern) + weight_tensor = match_result.get_tensor(weight_pattern) + activation_op = match_result.get_op(bias_add_pattern) + yield _LayerMatch(layer_op, weight_tensor, activation_op, None, None) + class _LayerMatch(object): """Contains all information related to a matched Layer.""" @@ -235,9 +253,10 @@ class _LayerMatch(object): def _InsertQuantOp(context, + name, producer, consumers, - name, + is_training, moving_avg=True, init_min=-6.0, init_max=6.0, @@ -245,16 +264,16 @@ def _InsertQuantOp(context, ema_decay=0.999, quant_delay=None, vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, - is_training=True, narrow_range=False): """Inserts a quant op between a producer op and (multiple) consumer ops. Args: context: Context w,here producer and consumer operations are nested. + name: Name for the new quantization op within the context. producer: Producer operation of the pairs where quantization will be inserted. consumers: Consumer operations of the pairs. - name: Name for the new quantization op within the context. + is_training: Whether quantizing training graph or eval graph. moving_avg: Specifies whether to use exponential moving average or just the last value seen. init_min: Starting minimum value for the new quantization op. @@ -268,7 +287,6 @@ def _InsertQuantOp(context, training. vars_collection: (Optional) Collection where to store the variables for quantization interval ends. - is_training: (Optional) Whether quantizing training graph or eval graph. narrow_range: Whether to use the narrow quantization range [1; 2^bits - 1] or wide range [0; 2^bits - 1]. Raises: diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py index b91e0451755183cb2e143fdee278fee1c363db97..5a3a74cec4864ad3808d485849334c81f569d300 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph.py +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -63,13 +63,13 @@ def _create_graph(input_graph=None, is_training=is_training) quantize.Quantize( input_graph, - is_training=is_training, + is_training, quant_delay=quant_delay, weight_bits=weight_bits, activation_bits=activation_bits) -def create_training_graph(input_graph=None, quant_delay=250000): +def create_training_graph(input_graph=None, quant_delay=0): """Rewrites a training input_graph in place for simulated quantization. The graph has fake quantization ops inserted to simulate the error @@ -77,6 +77,14 @@ def create_training_graph(input_graph=None, quant_delay=250000): the expected behavior of previously held references to nodes and tensors may change. + The default value of quant_delay is suitable for finetuning an already trained + floating point model (recommended). + If one wants to train a quantized model from scratch, quant_delay should be + set to the number of steps it take the floating point model to converge. + Quantization will be activated at this point and effectively finetune the + model. If quant_delay is not provided when training from scratch, training can + often fail. + Args: input_graph: The tf.Graph to be transformed. quant_delay: Number of steps after which weights and activations are @@ -93,12 +101,12 @@ def create_training_graph(input_graph=None, quant_delay=250000): # Corresponds to case of restoring from a floating point checkpoint # In this case, we can freeze the moving mean and variance early on and # switch to using them during training. Therefore, freeze_bn_delay is set to - # 200000 - freeze_bn_delay = 200000 + # 2e5. + freeze_bn_delay = int(2e5) else: # If training from scratch, set freeze_bn_delay to 100 epochs after quant # delay. With a batch size of 64, this corresponds to 20000*100=2M steps. - freeze_bn_delay = quant_delay + 2000000 + freeze_bn_delay = quant_delay + int(2e6) _create_graph( input_graph=input_graph, @@ -129,8 +137,8 @@ def create_eval_graph(input_graph=None): def experimental_create_training_graph(input_graph=None, weight_bits=8, activation_bits=8, - quant_delay=250000, - freeze_bn_delay=500000): + quant_delay=0, + freeze_bn_delay=int(2e5)): """Rewrites a training input_graph in place for simulated quantization. This function has additional experimental options not (yet) available to @@ -141,6 +149,14 @@ def experimental_create_training_graph(input_graph=None, the expected behavior of previously held references to nodes and tensors may change. + The default value of quant_delay is suitable for finetuning an already trained + floating point model (recommended). + If one wants to train a quantized model from scratch, quant_delay should be + set to the number of steps it take the floating point model to converge. + Quantization will be activated at this point and effectively finetune the + model. If quant_delay is not provided when training from scratch, training can + often fail. + Args: input_graph: The tf.Graph to be transformed,if None then defaults to the default graph. diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py index c57fcd4e4e6ef9e31890baa3339c8c7037d47874..b9d03c1bc059fe7bcce75978f503cbbf76090dbd 100644 --- a/tensorflow/contrib/quantize/python/quantize_graph_test.py +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -28,13 +28,11 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest -# TODO(suharshs): Add tests for testing experimental APIs and additional -# input arguments class QuantizeGraphTest(test_util.TensorFlowTestCase): # We have a lot of other tests that test the details of the rewrite, here we # just the specific features of the quantize_graph API. - def _RunTestOverParameters(self, test_fn): + def _RunTestOverAllRewrites(self, test_fn): rewrite_fns = [ quantize_graph.create_training_graph, quantize_graph.create_eval_graph, @@ -44,71 +42,202 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase): for fn in rewrite_fns: test_fn(fn) + def _RunTestOverTrainingRewrites(self, test_fn): + rewrite_fns = [ + quantize_graph.create_training_graph, + quantize_graph.experimental_create_training_graph, + ] + for fn in rewrite_fns: + test_fn(fn) + + def _RunTestOverEvalRewrites(self, test_fn): + rewrite_fns = [ + quantize_graph.create_eval_graph, + quantize_graph.experimental_create_eval_graph, + ] + for fn in rewrite_fns: + test_fn(fn) + + def _RunTestOverExperimentalRewrites(self, test_fn): + rewrite_fns = [ + quantize_graph.experimental_create_training_graph, + quantize_graph.experimental_create_eval_graph, + ] + for fn in rewrite_fns: + test_fn(fn) + def testRewrite(self): - self._RunTestOverParameters(self._TestRewrite) + self._RunTestOverAllRewrites(self._TestRewrite) - def _TestRewrite(self, fn): + def _TestRewrite(self, rewrite_fn): graph = ops.Graph() with graph.as_default(): - batch_size, height, width, depth = 5, 128, 128, 3 - inputs = array_ops.zeros((batch_size, height, width, depth)) - conv = layers.conv2d( - inputs, - 32, [5, 5], - stride=2, - padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=None, - scope='test') - _ = nn_ops.relu6(conv) + self._ConvLayer() orig_variable_names = set( [v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) - fn(graph) + rewrite_fn(graph) q_variables = graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) # Ensure that variables were added. self.assertTrue(len(orig_variable_names) < len(q_variables)) def testDefaultGraph(self): - self._RunTestOverParameters(self._TestRewrite) + self._RunTestOverAllRewrites(self._TestRewrite) - def _TestDefaultGraph(self, fn): + def _TestDefaultGraph(self, rewrite_fn): + # Tests that the default graph is correctly used when no args are provided + # to rewrite_fn. with ops.Graph().as_default() as g: - batch_size, height, width, depth = 5, 128, 128, 3 - inputs = array_ops.zeros((batch_size, height, width, depth)) - conv = layers.conv2d( - inputs, - 32, [5, 5], - stride=2, - padding='SAME', - weights_initializer=self._WeightInit(0.09), - activation_fn=None, - scope='test') - _ = nn_ops.relu6(conv) - + self._ConvLayer() orig_variable_names = set( [v.name for v in g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) - - fn() + rewrite_fn() q_variables = g.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) # Ensure that variables were added. self.assertTrue(len(orig_variable_names) < len(q_variables)) - def _WeightInit(self, stddev): - """Returns truncated normal variable initializer. - - Function is defined purely to shorten the name so that it stops wrapping. + def testQuantDelay(self): + self._RunTestOverTrainingRewrites(self._TestQuantDelay) - Args: - stddev: Standard deviation of normal variable. - - Returns: - An initialized that initialzes with a truncated normal variable. - """ - return init_ops.truncated_normal_initializer(stddev=stddev) + def _TestQuantDelay(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + quant_delay = 100 + rewrite_fn(quant_delay=quant_delay) + + quant_delay_found = False + for op in g.get_operations(): + # Check to see if the quant_delay is correctly set. + if 'activate_quant' in op.name and op.type == 'Const': + quant_delay_found = True + const_value = str(op.get_attr('value')) + self.assertTrue(('int64_val: %i' % quant_delay) in const_value) + self.assertTrue(quant_delay_found) + + def testWeightBits(self): + self._RunTestOverExperimentalRewrites(self._TestWeightBits) + + def _TestWeightBits(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + weight_bits = 4 + rewrite_fn(weight_bits=weight_bits) + + weights_quant_found = False + for op in g.get_operations(): + # Check to see if FakeQuant operations for weights have the right bits + # set. + if 'weights_quant' in op.name and op.type == 'FakeQuantWithMinMaxVars': + weights_quant_found = True + self.assertEqual(op.get_attr('num_bits'), weight_bits) + self.assertTrue(weights_quant_found) + + def testActivationBits(self): + self._RunTestOverExperimentalRewrites(self._TestActivationBits) + + def _TestActivationBits(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + activation_bits = 4 + rewrite_fn(activation_bits=activation_bits) + + act_quant_found = False + for op in g.get_operations(): + # Check to see if FakeQuant operations for activations have the right bits + # set. + act_quant_names = ['act_quant', 'conv_quant', 'add_quant'] + if any(s in op.name + for s in act_quant_names) and op.type == 'FakeQuantWithMinMaxVars': + act_quant_found = True + self.assertEqual(op.get_attr('num_bits'), activation_bits) + self.assertTrue(act_quant_found) + + def testTrainingQuantization(self): + self._RunTestOverTrainingRewrites(self._TestTrainingQuantization) + + def _TestTrainingQuantization(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + rewrite_fn() + + # Ensure that FakeQuant and variable update nodes were found. + quant_found = False + assign_min_last_found = False + assign_min_ema_found = False + assign_max_last_found = False + assign_max_ema_found = False + for op in g.get_operations(): + # Check that FakeQuant operations were added. + if op.type == 'FakeQuantWithMinMaxVars': + quant_found = True + # Check that update operations for the added min max variables exist in + # the graph. + if 'AssignMinLast' in op.name: + assign_min_last_found = True + elif 'AssignMinEma' in op.name: + assign_min_ema_found = True + elif 'AssignMaxLast' in op.name: + assign_max_last_found = True + elif 'AssignMaxEma' in op.name: + assign_max_ema_found = True + self.assertTrue(assign_min_last_found) + self.assertTrue(assign_min_ema_found) + self.assertTrue(assign_max_last_found) + self.assertTrue(assign_max_ema_found) + self.assertTrue(quant_found) + + def testEvalQuantization(self): + self._RunTestOverEvalRewrites(self._TestEvalQuantization) + + def _TestEvalQuantization(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + rewrite_fn() + + # Ensure that FakeQuant and variable update nodes were found. + quant_found = False + for op in g.get_operations(): + # Check that FakeQuant operations were added. + if op.type == 'FakeQuantWithMinMaxVars': + quant_found = True + # Check that update operations for the added min max variables don't + # exist in the graph. + update_names = [ + 'AssignMinLast', 'AssignMinEma', 'AssignMaxLast', 'AssignMaxEma' + ] + self.assertFalse(any(s in op.name for s in update_names)) + self.assertTrue(quant_found) + + def testIdempotent(self): + self._RunTestOverAllRewrites(self._TestIdempotent) + + def _TestIdempotent(self, rewrite_fn): + with ops.Graph().as_default() as g: + self._ConvLayer() + rewrite_fn() + graph_def_before = str(g.as_graph_def()) + # Ensuring that calling the rewrite again doesn't add more nodes. + rewrite_fn() + graph_def_after = str(g.as_graph_def()) + self.assertEqual(graph_def_before, graph_def_after) + + def _ConvLayer(self): + """Add a basic convolution layer to the default graph.""" + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + weight_init = init_ops.truncated_normal_initializer + conv = layers.conv2d( + inputs, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=weight_init(0.09), + activation_fn=None, + scope='test') + _ = nn_ops.relu6(conv) if __name__ == '__main__': diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py index 2e74f3b04dc4f53222c769a9b3e48a1b1338ba15..639a7454a92aebd7289c59498cebff82cc003f75 100644 --- a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -88,7 +88,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph, quant_delay=delay) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) @@ -164,7 +164,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph, quant_delay=delay) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + @@ -240,7 +240,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph, quant_delay=delay) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + @@ -363,9 +363,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - fold_batch_norms.FoldBatchNorms(graph) + fold_batch_norms.FoldBatchNorms(graph, is_training=True) - quantize.Quantize(graph, quant_delay=delay) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + @@ -446,9 +446,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - fold_batch_norms.FoldBatchNorms(graph) + fold_batch_norms.FoldBatchNorms(graph, is_training=True) - quantize.Quantize(graph, quant_delay=delay) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + @@ -534,9 +534,9 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - fold_batch_norms.FoldBatchNorms(graph) + fold_batch_norms.FoldBatchNorms(graph, is_training=True) - quantize.Quantize(graph, quant_delay=delay) + quantize.Quantize(graph, True, quant_delay=delay) quantization_node_name = 'FakeQuantWithMinMaxVars' weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + quantization_node_name) diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py index 53cbd667410fe50336601c357c16754dff6596c8..ef59475167137e203db2f6ca7f43c7b8f1938060 100644 --- a/tensorflow/contrib/quantize/python/quantize_test.py +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -35,7 +35,15 @@ separable_conv2d = layers.separable_conv2d class QuantizeTest(test_util.TensorFlowTestCase): + def _RunTestOverParameters(self, test_fn): + params = [True, False] + for is_training in params: + test_fn(is_training) + def testInsertQuantOpFailsWhenOpsNotConnected(self): + pass + + def _TestInsertQuantOpFailsWhenOpsNotConnected(self, is_training): graph = ops.Graph() with graph.as_default(): batch_size, height, width, depth = 5, 128, 128, 3 @@ -48,11 +56,15 @@ class QuantizeTest(test_util.TensorFlowTestCase): # Inserting a quantization op between two unconnected ops should fail with # ValueError. with self.assertRaises(ValueError) as err: - quantize._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp') + quantize._InsertQuantOp('test', is_training, conv.op, [relu.op], + 'FailingQuantOp') self.assertEqual( str(err.exception), 'Some inputs not quantized for ops: [Relu6]') def testInsertQuantOpForAddAfterConv2d(self): + self._RunTestOverParameters(self._TestInsertQuantOpForAddAfterConv2d) + + def _TestInsertQuantOpForAddAfterConv2d(self, is_training): graph = ops.Graph() with graph.as_default(): batch_size, height, width, depth = 5, 128, 128, 3 @@ -67,7 +79,7 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8) + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) quantization_node_name = 'FakeQuantWithMinMaxVars' add_quant = graph.get_operation_by_name('test/add_quant/' + @@ -75,6 +87,10 @@ class QuantizeTest(test_util.TensorFlowTestCase): self.assertEqual(add_quant.type, quantization_node_name) def testInsertQuantOpForAddAfterSeparableConv2d(self): + self._RunTestOverParameters( + self._TestInsertQuantOpForAddAfterSeparableConv2d) + + def _TestInsertQuantOpForAddAfterSeparableConv2d(self, is_training): graph = ops.Graph() with graph.as_default(): batch_size, height, width, depth = 5, 128, 128, 3 @@ -90,13 +106,35 @@ class QuantizeTest(test_util.TensorFlowTestCase): with ops.control_dependencies([update_barrier]): array_ops.identity(node, name='control_dependency') - quantize.Quantize(graph=graph, weight_bits=8, activation_bits=8) + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) quantization_node_name = 'FakeQuantWithMinMaxVars' add_quant = graph.get_operation_by_name('test/add_quant/' + quantization_node_name) self.assertEqual(add_quant.type, quantization_node_name) + def testFinalLayerQuantized(self): + self._RunTestOverParameters(self._TestFinalLayerQuantized) + + def _TestFinalLayerQuantized(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + _ = conv2d( + input1, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test') + # Ensure that the a FakeQuant operation is in the outputs of the BiasAdd. + bias_add_op = graph.get_operation_by_name('test/BiasAdd') + quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8) + self.assertTrue('FakeQuantWithMinMaxVars' in + [op.type for op in bias_add_op.outputs[0].consumers()]) + def _WeightInit(self, stddev): """Returns truncated normal variable initializer. diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index f7007173943c99d08791c125b906d4befe6387ea..4eb4fbcd92f0d7cb3bee712862c8950a1971b632 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -572,9 +572,8 @@ class LSTMBlockWrapper(base_layer.Layer): def _gather_states(self, data, indices, batch_size): """Produce `out`, s.t. out(i, j) = data(indices(i), i, j).""" - mod_indices = indices * batch_size + math_ops.range(batch_size) - return array_ops.gather( - array_ops.reshape(data, [-1, self.num_units]), mod_indices) + return array_ops.gather_nd( + data, array_ops.stack([indices, math_ops.range(batch_size)], axis=1)) class LSTMBlockFusedCell(LSTMBlockWrapper): diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index dce71c393aa14a026203c1da09635df5d08eb46f..a6c2d9cdbb2b6f61d59960f708000e945c6115e9 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -424,8 +424,9 @@ class TimeFreqLSTMCell(rnn_cell_impl.RNNCell): "W_O_diag", shape=[self._num_units], dtype=dtype) # initialize the first freq state to be zero - m_prev_freq = array_ops.zeros([int(inputs.get_shape()[0]), self._num_units], - dtype) + m_prev_freq = array_ops.zeros( + [inputs.shape[0].value or inputs.get_shape()[0], self._num_units], + dtype) for fq in range(len(freq_inputs)): c_prev = array_ops.slice(state, [0, 2 * fq * self._num_units], [-1, self._num_units]) diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index 64973ccccdc962757a727d7183bd70e94edcfd1b..dfa12e873a6aca806031c48d6f92e0432d0ea6e0 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -80,12 +80,12 @@ class GatherTreeOp : public OpKernel { max_sequence_lengths.shape().DebugString())); Tensor* beams; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams)); - typename TTypes::ConstTensor step_ids_t = step_ids.tensor(); - typename TTypes::ConstTensor parent_ids_t = parent_ids.tensor(); + typename TTypes::ConstTensor step_ids_t(step_ids.tensor()); + typename TTypes::ConstTensor parent_ids_t(parent_ids.tensor()); typename TTypes::ConstVec max_seq_lens_t = max_sequence_lengths.vec(); - typename TTypes::ConstScalar end_token_t = end_token.scalar(); - typename TTypes::Tensor beams_t = beams->tensor(); + typename TTypes::ConstScalar end_token_t(end_token.scalar()); + typename TTypes::Tensor beams_t(beams->tensor()); const T end_token_value = end_token_t(); functor::GatherTree()(ctx, device, step_ids_t, parent_ids_t, max_seq_lens_t, end_token_value, beams_t); diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index d6184d61095f727f9dcab56fe59e2601868c1624..554eb24e5260724a905b099091bf8aea461554cf 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -724,7 +724,7 @@ def _mask_probs(probs, eos_token, finished): eos_token, vocab_size, dtype=probs.dtype, - on_value=0., + on_value=ops.convert_to_tensor(0., dtype=probs.dtype), off_value=probs.dtype.min) finished_probs = array_ops.tile( array_ops.reshape(finished_row, [1, 1, -1]), diff --git a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py index ad5e985487190e72b9eb2809da964f3d7b34ef94..b3343aef47d9f352c3bcbef4afbe8f9bf2560e6d 100644 --- a/tensorflow/contrib/slim/python/slim/data/parallel_reader.py +++ b/tensorflow/contrib/slim/python/slim/data/parallel_reader.py @@ -221,7 +221,7 @@ def parallel_read(data_sources, the data will be cycled through indefinitely. num_readers: a integer, number of Readers to create. reader_kwargs: an optional dict, of kwargs for the reader. - shuffle: boolean, wether should shuffle the files and the records by using + shuffle: boolean, whether should shuffle the files and the records by using RandomShuffleQueue as common_queue. dtypes: A list of types. The length of dtypes must equal the number of elements in each record. If it is None it will default to diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index 0544404e9e252cca6d3650b805b91be25d705eea..b3b61e1dfe5671a7fbbee20b0c577ee5fad0fb9b 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -349,7 +349,8 @@ class Image(ItemHandler): shape=None, channels=3, dtype=dtypes.uint8, - repeated=False): + repeated=False, + dct_method=''): """Initializes the image. Args: @@ -368,6 +369,11 @@ class Image(ItemHandler): tf.decode_raw, repeated: if False, decodes a single image. If True, decodes a variable number of image strings from a 1D tensor of strings. + dct_method: An optional string. Defaults to empty string. It only takes + effect when image format is jpeg, used to specify a hint about the + algorithm used for jpeg decompression. Currently valid values + are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be ignored, for + example, the jpeg library does not have that specific option. """ if not image_key: image_key = 'image/encoded' @@ -381,6 +387,7 @@ class Image(ItemHandler): self._channels = channels self._dtype = dtype self._repeated = repeated + self._dct_method = dct_method def tensors_to_item(self, keys_to_tensors): """See base class.""" @@ -406,9 +413,25 @@ class Image(ItemHandler): A tensor that represents decoded image of self._shape, or (?, ?, self._channels) if self._shape is not specified. """ + def decode_image(): - """Decodes a png or jpg based on the headers.""" - return image_ops.decode_image(image_buffer, self._channels) + """Decodes a image based on the headers.""" + return image_ops.decode_image(image_buffer, channels=self._channels) + + def decode_jpeg(): + """Decodes a jpeg image with specified '_dct_method'.""" + return image_ops.decode_jpeg( + image_buffer, channels=self._channels, dct_method=self._dct_method) + + def check_jpeg(): + """Checks if an image is jpeg.""" + # For jpeg, we directly use image_ops.decode_jpeg rather than decode_image + # in order to feed the jpeg specify parameter 'dct_method'. + return control_flow_ops.cond( + image_ops.is_jpeg(image_buffer), + decode_jpeg, + decode_image, + name='cond_jpeg') def decode_raw(): """Decodes a raw image.""" @@ -420,7 +443,7 @@ class Image(ItemHandler): math_ops.equal(image_format, 'RAW')): decode_raw, } image = control_flow_ops.case( - pred_fn_pairs, default=decode_image, exclusive=True) + pred_fn_pairs, default=check_jpeg, exclusive=True) image.set_shape([None, None, self._channels]) if self._shape is not None: diff --git a/tensorflow/contrib/slim/python/slim/evaluation_test.py b/tensorflow/contrib/slim/python/slim/evaluation_test.py index 7ab6805fac631d6f6b475c4c91f7e3873e7ffea5..c24bd048512daaae116e732ac437f7c9b6f6d7fc 100644 --- a/tensorflow/contrib/slim/python/slim/evaluation_test.py +++ b/tensorflow/contrib/slim/python/slim/evaluation_test.py @@ -29,6 +29,7 @@ from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.metrics.python.ops import metric_ops from tensorflow.contrib.slim.python.slim import evaluation from tensorflow.contrib.training.python.training import evaluation as evaluation_lib +from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.debug.lib import debug_data from tensorflow.python.debug.wrappers import hooks from tensorflow.python.framework import constant_op @@ -235,7 +236,7 @@ class SingleEvaluationTest(test.TestCase): def _prepareCheckpoint(self, checkpoint_path): init_op = control_flow_ops.group(variables.global_variables_initializer(), variables.local_variables_initializer()) - saver = saver_lib.Saver() + saver = saver_lib.Saver(write_version=saver_pb2.SaverDef.V1) with self.test_session() as sess: sess.run(init_op) saver.save(sess, checkpoint_path) diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index 068ae35c712622117127bb5b3dfa341a48254c54..b6249fc92f712b21197c2167fb5d1c4af1f48ca5 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -110,7 +110,7 @@ class SummaryWriter(object): def __init__(self, resource): self._resource = resource - if context.in_eager_mode(): + if context.in_eager_mode() and self._resource is not None: self._resource_deleter = resource_variable_ops.EagerResourceDeleter( handle=self._resource, handle_device="cpu:0") diff --git a/tensorflow/contrib/summary/summary_test_internal.py b/tensorflow/contrib/summary/summary_test_internal.py new file mode 100644 index 0000000000000000000000000000000000000000..d0d3384735fb1eb1a048c7aa6da0037ee9fc6936 --- /dev/null +++ b/tensorflow/contrib/summary/summary_test_internal.py @@ -0,0 +1,60 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Internal helpers for tests in this directory.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +import sqlite3 + +from tensorflow.contrib.summary import summary_ops +from tensorflow.python.framework import test_util + + +class SummaryDbTest(test_util.TensorFlowTestCase): + """Helper for summary database testing.""" + + def setUp(self): + super(SummaryDbTest, self).setUp() + self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite') + if os.path.exists(self.db_path): + os.unlink(self.db_path) + self.db = sqlite3.connect(self.db_path) + self.create_db_writer = functools.partial( + summary_ops.create_db_writer, + db_uri=self.db_path, + experiment_name='experiment', + run_name='run', + user_name='user') + + def tearDown(self): + self.db.close() + super(SummaryDbTest, self).tearDown() + + +def get_one(db, q, *p): + return db.execute(q, p).fetchone()[0] + + +def get_all(db, q, *p): + return unroll(db.execute(q, p).fetchall()) + + +def unroll(list_of_tuples): + return sum(list_of_tuples, ()) diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py index bda57e6a0ca8e1ddb979a80de276911c7738f0aa..8506c4be9c4ca8305b62da17c7246e6e18313bd3 100644 --- a/tensorflow/contrib/summary/summary_test_util.py +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -21,6 +21,7 @@ from __future__ import print_function import functools import os + import sqlite3 from tensorflow.contrib.summary import summary_ops diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index 58a7fa095d8356229fdb5879bea99d316113c828..1e4cc3f0952ef74a1c89b7ed2d8c357fa8847ad5 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -497,6 +497,7 @@ py_library( ":tensor_forest_v4_ops_py", "//tensorflow/contrib/decision_trees/proto:generic_tree_model_py", "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/tensor_forest/proto:fertile_stats_proto_py", "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", diff --git a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h index 04e6b0a735320dd024e326a94ef910593a326245..dc3e9fe79d32a19930d500b62b520eddb4b41aa8 100644 --- a/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h +++ b/tensorflow/contrib/tensor_forest/kernels/v4/grow_stats.h @@ -468,7 +468,7 @@ class FixedSizeSparseClassificationGrowStats : public ClassificationStats { void PackToProto(FertileSlot* slot) const override; void InitLeafClassStats(int best_split_index, LeafStat* left_stats, - LeafStat* right_stats) const; + LeafStat* right_stats) const override; protected: void ClassificationAddSplitStats() override { diff --git a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc index 3868b1172f21d3ba5a3fd1c71525207eaeb06304..85b3e7231bcb433e9510522597c03c5f764f06cf 100644 --- a/tensorflow/contrib/tensorboard/db/summary_file_writer.cc +++ b/tensorflow/contrib/tensorboard/db/summary_file_writer.cc @@ -47,9 +47,9 @@ class SummaryFileWriter : public SummaryWriterInterface { mutex_lock ml(mu_); events_writer_ = tensorflow::MakeUnique(io::JoinPath(logdir, "events")); - if (!events_writer_->InitWithSuffix(filename_suffix)) { - return errors::Unknown("Could not initialize events writer."); - } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + events_writer_->InitWithSuffix(filename_suffix), + "Could not initialize events writer."); last_flush_ = env_->NowMicros(); is_initialized_ = true; return Status::OK(); @@ -151,9 +151,8 @@ class SummaryFileWriter : public SummaryWriterInterface { events_writer_->WriteEvent(*e); } queue_.clear(); - if (!events_writer_->Flush()) { - return errors::InvalidArgument("Could not flush events file."); - } + TF_RETURN_WITH_CONTEXT_IF_ERROR(events_writer_->Flush(), + "Could not flush events file."); last_flush_ = env_->NowMicros(); return Status::OK(); } diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index d62bca353ab3d20751774ff540d3a07a83ec11ab..1010a8988d7a47a99662cd6a855ae128eda349f8 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -46,18 +46,18 @@ tf_cuda_cc_test( ) tf_custom_op_library( - name="python/ops/_trt_engine_op.so", - srcs=[ - "ops/trt_calib_op.cc", - "ops/trt_engine_op.cc", - ], - deps=[ - ":trt_engine_op_kernel", - ":trt_shape_function", - "//tensorflow/core:lib_proto_parsing", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), + name = "python/ops/_trt_engine_op.so", + srcs = [ + "ops/trt_calib_op.cc", + "ops/trt_engine_op.cc", + ], + deps = [ + ":trt_engine_op_kernel", + ":trt_shape_function", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), ) tf_cuda_library( @@ -73,37 +73,38 @@ tf_cuda_library( ) cc_library( - name="trt_engine_op_kernel", - srcs=[ - "kernels/trt_calib_op.cc", - "kernels/trt_engine_op.cc", - ], - hdrs=[ - "kernels/trt_calib_op.h", - "kernels/trt_engine_op.h", - ], - copts=tf_copts(), - deps=[ - ":trt_logging", - ":trt_resources", - "//tensorflow/core:gpu_headers_lib", - "//tensorflow/core:lib_proto_parsing", - "//tensorflow/core:stream_executor_headers_lib", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), - alwayslink=1, - visibility=["//visibility:public"], + name = "trt_engine_op_kernel", + srcs = [ + "kernels/trt_calib_op.cc", + "kernels/trt_engine_op.cc", + ], + hdrs = [ + "kernels/trt_calib_op.h", + "kernels/trt_engine_op.h", + ], + copts = tf_copts(), + deps = [ + ":trt_logging", + ":trt_resources", + "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:stream_executor_headers_lib", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), + visibility = ["//visibility:public"], + # TODO(laigd) + alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs ) tf_gen_op_libs( - op_lib_names=[ - "trt_engine_op", - "trt_calib_op", - ], - deps=if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]), + op_lib_names = [ + "trt_engine_op", + "trt_calib_op", + ], + deps = if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), ) tf_cuda_library( @@ -119,13 +120,13 @@ tf_cuda_library( ) tf_gen_op_wrapper_py( - name="trt_engine_op", - deps=[ - ":trt_engine_op_op_lib", - ":trt_calib_op_op_lib", - ":trt_logging", - ":trt_shape_function", - ], + name = "trt_engine_op", + deps = [ + ":trt_calib_op_op_lib", + ":trt_engine_op_op_lib", + ":trt_logging", + ":trt_shape_function", + ], ) tf_custom_op_py_library( @@ -205,37 +206,58 @@ tf_cuda_library( ], ) +tf_cuda_library( + name = "trt_resources", + srcs = [ + "resources/trt_int8_calibrator.cc", + "resources/trt_resource_manager.cc", + ], + hdrs = [ + "resources/trt_int8_calibrator.h", + "resources/trt_resource_manager.h", + "resources/trt_resources.h", + ], + deps = [ + ":trt_logging", + "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + # Library for the node-level conversion portion of TensorRT operation creation tf_cuda_library( - name="trt_conversion", - srcs=[ - "convert/convert_graph.cc", - "convert/convert_nodes.cc", - ], - hdrs=[ - "convert/convert_graph.h", - "convert/convert_nodes.h", - ], - deps=[ - ":segment", - ":trt_logging", - ":trt_resources", - "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/grappler:utils", - "//tensorflow/core:framework", - "//tensorflow/core:framework_lite", - "//tensorflow/core:graph", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/grappler:devices", - "//tensorflow/core/grappler/clusters:virtual_cluster", - "//tensorflow/core/grappler/costs:graph_properties", - "//tensorflow/core/grappler/optimizers:constant_folding", - "//tensorflow/core/grappler/optimizers:layout_optimizer", - ] + if_tensorrt([ - "@local_config_tensorrt//:nv_infer", - ]) + tf_custom_op_library_additional_deps(), + name = "trt_conversion", + srcs = [ + "convert/convert_graph.cc", + "convert/convert_nodes.cc", + ], + hdrs = [ + "convert/convert_graph.h", + "convert/convert_nodes.h", + ], + deps = [ + ":segment", + ":trt_logging", + ":trt_resources", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:devices", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/optimizers:constant_folding", + "//tensorflow/core/grappler/optimizers:layout_optimizer", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), ) # Library for the segmenting portion of TensorRT operation creation diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index 638fdebcac61044c3c1ca04e6a7dd74f06e49e70..d753e272f4445d4ee1dc7c9b6f4ac05327b7c909 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#include "tensorflow/core/protobuf/device_properties.pb.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" // NOLINT //#if GOOGLE_CUDA //#if GOOGLE_TENSORRT diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc index d0c7e004282c9900704163c5102908624871dc0d..1dcb87e7683ad73b1f5f894b61a15a16d36cfcdf 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc +++ b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.cc @@ -14,10 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/tensorrt/kernels/trt_calib_op.h" -#include "tensorflow/contrib/tensorrt/resources/TRTInt8Calibrator.h" -#include "tensorflow/contrib/tensorrt/resources/TRTResourceManager.h" -#include "tensorflow/contrib/tensorrt/resources/TRTResources.h" -#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/contrib/tensorrt/resources/trt_resources.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_types.h" @@ -29,108 +28,102 @@ limitations under the License. #include "tensorrt/include/NvInfer.h" namespace tensorflow { -namespace trt { +namespace tensorrt { + TRTCalibOp::TRTCalibOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("segment_nodes", &segment_nodes_)); OP_REQUIRES_OK(context, context->GetAttr("input_names", &input_names_)); - OP_REQUIRES_OK(context, context->GetAttr("resource_name", &repo_name)); + OP_REQUIRES_OK(context, context->GetAttr("resource_name", &resource_name_)); }; -#define TYPECASE(dt, X, Y) \ - case dt: { \ - Y = (void*)X->flat::Type>().data(); \ - break; \ +#define TYPECASE(dt, X, Y) \ + case dt: { \ + return (void*)X->flat::Type>().data(); \ } -#define GET_TENSOR_ADDRESS(tensor_ptr, dest_ptr) \ - { \ - auto TENSOR_TYPE = tensor_ptr->dtype(); \ - switch (TENSOR_TYPE) { \ - TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); \ - TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); \ - TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); \ - default: { \ - LOG(FATAL) << "Unsupported Data type " \ - << tensorflow::DataTypeString(TENSOR_TYPE); \ - break; \ - } \ - } \ + +void* GetTensorAddress(const Tensor* tensor_ptr) { + auto tensor_type = tensor_ptr->dtype(); + switch (tensor_type) { + TYPECASE(tensorflow::DT_FLOAT, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_HALF, tensor_ptr, dest_ptr); + TYPECASE(tensorflow::DT_INT8, tensor_ptr, dest_ptr); + default: { + LOG(FATAL) << "Unsupported Data type " + << tensorflow::DataTypeString(tensor_type); + return nullptr; + } } +} + void TRTCalibOp::Compute(tensorflow::OpKernelContext* ctx) { - auto trt_rm = tensorflow::trt::TRTResourceManager::instance(); - VLOG(2) << "Op Name= " << name() << " nodedef name= " << repo_name; - auto resmgr = trt_rm->getManager("TRTCalibOps"); - tensorflow::trt::TRTCalibrationResource* calibRes = nullptr; - auto status = resmgr->Lookup(repo_name, repo_name, &calibRes); - if (status.ok()) { - int numInputs = ctx->num_inputs(); - if (calibRes->calibrator == nullptr) { - dev_tensors_.resize(numInputs); - int batchSize = ctx->input(0).dim_size(0); - VLOG(1) << " Constructing calibrator"; - // first run - for (int i = 0; i < numInputs; i++) { - const tensorflow::Tensor& t = ctx->input(i); - VLOG(1) << "Tensor " << i << " " << t.shape().DebugString(); - OP_REQUIRES_OK(ctx, - ctx->allocate_persistent(t.dtype(), t.shape(), - &dev_tensors_.at(i), nullptr)); - const auto dTensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), dTensor->TotalBytes()); - void* devAddr = nullptr; - GET_TENSOR_ADDRESS(dTensor, devAddr); - device_buffers_.emplace( - input_names_.at(i), - std::pair(devAddr, dTensor->TotalBytes())); - } - calibRes->calibrator = - new TRTInt8Calibrator(device_buffers_, batchSize, repo_name); - string label(repo_name); - calibRes->thr = new std::thread([calibRes, label]() { - VLOG(0) << "Starting calibration thread, Calibration Resource @ " - << calibRes; - calibRes->builder->setInt8Calibrator(calibRes->calibrator); - calibRes->builder->setInt8Mode(true); - calibRes->engine = calibRes->builder->buildCudaEngine( - *calibRes->network); // will loop until we terminate calibrator - VLOG(0) << "SAMI Calibration loop terminated " << label; - }); - VLOG(0) << "SAMI initialized calibrator resource"; - } + // TODO(aaroey): make sure ctx->resource_mgr() is used in future PR. + auto trt_rm = tensorflow::tensorrt::TRTResourceManager::instance(); + auto res_mgr = trt_rm->getManager("TRTCalibOps"); + tensorflow::tensorrt::TRTCalibrationResource* calib_res = nullptr; + auto status = res_mgr->Lookup(resource_name_, resource_name_, &calib_res); - std::unordered_map input_data; - for (int i = 0; i < numInputs; i++) { - const Tensor& t = ctx->input(i); - void* data_address = nullptr; - const Tensor* t_ptr = &t; - GET_TENSOR_ADDRESS(t_ptr, data_address); - const auto dTensor = dev_tensors_.at(i).AccessTensor(ctx); - CHECK_EQ(t.TotalBytes(), - dTensor->TotalBytes()); // use the tensor so FW keeps it - if (VLOG_IS_ON(1)) { - void* devAddr = nullptr; - GET_TENSOR_ADDRESS(dTensor, devAddr); - if (devAddr != device_buffers_.at(input_names_.at(i)).first) { - LOG(WARNING) << "Device address is different!"; - } - } - input_data.emplace(input_names_.at(i), data_address); - ctx->set_output(i, t); - } - VLOG(2) << "Filled map for sending"; - calibRes->calibrator->setBatch(input_data); - VLOG(2) << "Passed calibration data"; - } else { + if (!status.ok()) { ctx->SetStatus(status); return; } + int num_inputs = ctx->num_inputs(); + // first run instantiate calibrator + if (calib_res->calibrator_ == nullptr) { + dev_tensors_.resize(num_inputs); + int batch_size = ctx->input(0).dim_size(0); + VLOG(1) << " Constructing calibrator"; + for (int i = 0; i < num_inputs; i++) { + // allocate workspace on device for inputs + const tensorflow::Tensor& t = ctx->input(i); + OP_REQUIRES_OK(ctx, + ctx->allocate_persistent(t.dtype(), t.shape(), + &dev_tensors_.at(i), nullptr)); + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), device_tensor->TotalBytes()); + void* device_address = GetTensorAddress(device_tensor); + device_buffers_.emplace(input_names_.at(i), + std::pair( + device_address, device_tensor->TotalBytes())); + } + + calib_res->calibrator_ = + new TRTInt8Calibrator(device_buffers_, batch_size, resource_name_); + string label(resource_name_); + calib_res->thr_ = new std::thread([calib_res, label]() { + VLOG(1) << "Starting calibration thread, Calibration Resource @ " + << calib_res; + calib_res->builder_->setInt8Calibrator(calib_res->calibrator_); + calib_res->builder_->setInt8Mode(true); + calib_res->engine_ = calib_res->builder_->buildCudaEngine( + *calib_res->network_); // will loop until we terminate calibrator + VLOG(1) << "Calibration loop terminated " << label; + }); + VLOG(1) << "initialized calibrator resource"; + } // calibrator initialized + + // Pass input data to calibrator + std::unordered_map input_data; + for (int i = 0; i < num_inputs; i++) { + const Tensor& t = ctx->input(i); + void* data_address = GetTensorAddress(&t); + const auto device_tensor = dev_tensors_.at(i).AccessTensor(ctx); + CHECK_EQ(t.TotalBytes(), + device_tensor->TotalBytes()); // use the tensor so FW keeps it + input_data.emplace(input_names_.at(i), data_address); + ctx->set_output(i, t); + } + VLOG(2) << "Filled map for sending"; + calib_res->calibrator_->setBatch(input_data); + VLOG(2) << "Passed calibration data"; + // TODO(aaroey): make sure we wait for the completion of calibration on the + // last batch in future PR. }; #undef TYPECASE -#undef GET_TENSOR_ADDRESS REGISTER_KERNEL_BUILDER(Name("TRTCalibOp").Device(DEVICE_GPU), TRTCalibOp); -} // namespace trt +} // namespace tensorrt } // namespace tensorflow #endif #endif diff --git a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h index 74232235826865c1b1fc500d942add5ac558b49f..23df9db32f077a080eaff7479fcbe90d6a504c42 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_calib_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_CONTRIB_TENSORRT_TRT_CALIB_OP_H -#define TENSORFLOW_CONTRIB_TENSORRT_TRT_CALIB_OP_H +#ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H +#define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H #include #include @@ -24,10 +24,12 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/types.h" + #if GOOGLE_CUDA #if GOOGLE_TENSORRT namespace tensorflow { -namespace trt { +namespace tensorrt { // TODO(sami): Convert this to async kernel! class TRTCalibOp : public OpKernel { public: @@ -36,15 +38,15 @@ class TRTCalibOp : public OpKernel { void Compute(OpKernelContext* context) override; private: - std::string repo_name; - std::vector segment_nodes_; - std::vector input_names_; + string resource_name_; + std::vector segment_nodes_; + std::vector input_names_; std::vector shapes_; - std::unordered_map> device_buffers_; + std::unordered_map> device_buffers_; std::vector dev_tensors_; }; -} // namespace trt +} // namespace tensorrt } // namespace tensorflow #endif #endif -#endif // TENSORFLOW_CONTRIB_TENSORRT_TRT_CALIB_OP_H +#endif // TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_CALIB_OP_H diff --git a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc index dab5a3e0e844a84fab4047cd993bab5e4fa9124e..4835e5065068ec7a59995eb7f6126b31aecf6704 100644 --- a/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc +++ b/tensorflow/contrib/tensorrt/ops/trt_calib_op.cc @@ -17,17 +17,18 @@ limitations under the License. #include "tensorflow/core/framework/shape_inference.h" namespace tensorflow { - REGISTER_OP("TRTCalibOp") - .Attr("segment_nodes: list(string)") // names of the ops in segment - .Attr("segment_output_names: list(string)") // names of the output ops in segment - .Attr("input_names: list(string)") // names of the inputs for passing into tensorrt + .Attr("segment_nodes: list(string)") // names of the ops in segment + .Attr("segment_output_names: list(string)") // names of the output ops in + // segment + .Attr("input_names: list(string)") // names of the inputs for + // passing into tensorrt .Attr("resource_name: string") .Attr("InT: list({int8, float16, float32})") .Input("in_tensor: InT") .Output("out_tensor: InT") .SetShapeFn([](tensorflow::shape_inference::InferenceContext* c) { - for (int i = 0; i < c->num_inputs(); i++){ + for (int i = 0; i < c->num_inputs(); i++) { c->set_output(i, c->input(i)); } return Status::OK(); diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d5cc76c4256bea70e75ea3dd9b1e87c951a9000 --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" + +#include +#include +#include + +#include "tensorflow/core/platform/logging.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "cuda_runtime_api.h" + +namespace tensorflow { +namespace tensorrt { + +// set the batch size before constructing the thread to execute engine +int TRTInt8Calibrator::getBatchSize() const { return batch_size_; } + +TRTInt8Calibrator::TRTInt8Calibrator( + const std::unordered_map>& dev_buffers, + int batch_size, string engine_name) + : batch_size_(batch_size), + done_(false), + dev_buffers_(dev_buffers), + calib_running_(false), + engine_name_(engine_name) {} + +bool TRTInt8Calibrator::setBatch( + const std::unordered_map& data) { + // TODO(aaroey): make sure that in future PR: + // 1. the mutex_lock is outside of the loop + // 2. wait() is used instead of wait_for() + // 3. done_ is to be protected by the mutex + // 4. the first batch is not missed + if (done_) return false; + while (calib_running_.load( + std::memory_order_acquire)) { // wait while calibration is running + tensorflow::mutex_lock l(cond_mtx_); + cond_.wait_for(l, std::chrono::milliseconds(50)); + if (done_) return false; + } + VLOG(1) << "Set Batch Waiting finished"; + for (const auto it : data) { + auto devptr = dev_buffers_.find(it.first); + if (devptr == dev_buffers_.end()) { + LOG(FATAL) << "FATAL " << engine_name_ << " input name '" << it.first + << "' does not match with the buffer names"; + } + const auto& d = devptr->second; + + // TODO(aaroey): we should not use sync copy on default stream. Make sure + // stream->ThenMemcpy() is used in future PRs. + auto status = + cudaMemcpy(d.first, it.second, d.second, cudaMemcpyDeviceToDevice); + if (status != cudaSuccess) { + LOG(FATAL) << "cudaMemcpy " << engine_name_ << " for '" << it.first + << "' failed with " << status; + } + } + calib_running_.store(true, std::memory_order_release); // release builder + cond_.notify_all(); + return true; +} + +bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, + int num_bindings) { + calib_running_.store(false, std::memory_order_release); // wait for new batch + cond_.notify_all(); + while (!calib_running_.load( + std::memory_order_acquire)) { // wait until new batch arrives + tensorflow::mutex_lock l(cond_mtx_); + cond_.wait_for(l, std::chrono::milliseconds(50)); + if (done_) return false; + } + if (done_) { + return false; + } + + for (int i = 0; i < num_bindings; i++) { + auto it = dev_buffers_.find(names[i]); + if (it == dev_buffers_.end()) { + LOG(FATAL) << "Calibration engine asked for unknown tensor name '" + << names[i] << "' at position " << i; + } + + bindings[i] = it->second.first; + } + return true; +} + +const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { + return nullptr; +} + +void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, + std::size_t length) {} +TRTInt8Calibrator::~TRTInt8Calibrator() { + VLOG(1) << "Destroying calibrator for " << engine_name_; +} + +} // namespace tensorrt +} // namespace tensorflow +#endif +#endif diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h new file mode 100644 index 0000000000000000000000000000000000000000..8830f7efe75b42eb82cffe5b07ddd3832b36145c --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ + +#include +#include +#include +#include +#include "tensorflow/core/platform/mutex.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorrt/include/NvInfer.h" +namespace tensorflow { +namespace tensorrt { +// This class provides a 1 element queue to match TFs push model to +// TRTs pull model for calibration. When TRT implements a means for +// a push calibration This class should be updated accordingly + +struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { + public: + TRTInt8Calibrator( + const std::unordered_map>& dev_buffers, + int batch_size, string engine_name); + int getBatchSize() const override; + bool getBatch(void* bindings[], const char* names[], + int num_bindings) override; + bool setBatch(const std::unordered_map& data); + void setDone() { done_ = true; } + const void* readCalibrationCache(std::size_t& length) override; + void writeCalibrationCache(const void* ptr, std::size_t length) override; + ~TRTInt8Calibrator(); + + private: + const int batch_size_; + tensorflow::mutex cond_mtx_; // mutex for condition_variable + tensorflow::condition_variable cond_; // condition variable to implement + // producer-consumer queue for + // calibration + bool done_; + const std::unordered_map> + dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with + // buffer names + std::atomic_bool calib_running_; + string engine_name_; +}; +} // namespace tensorrt +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ +#endif +#endif diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc new file mode 100644 index 0000000000000000000000000000000000000000..e663eed4dd6704e2f41bde1dfabd411e86669ecd --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.cc @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/tensorrt/resources/trt_resource_manager.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace tensorrt { + +std::shared_ptr +tensorflow::tensorrt::TRTResourceManager::getManager(const string& op_name) { + // mutex is held for lookup only. Most instantiations where mutex will be held + // longer will be during op creation and should be ok. + tensorflow::mutex_lock lock(map_mutex_); + auto s = managers_.find(op_name); + if (s == managers_.end()) { + auto it = managers_.emplace( + op_name, std::make_shared(op_name)); + VLOG(1) << "Returning a new manager " << op_name; + return it.first->second; + } + VLOG(1) << "Returning old manager " << op_name; + return s->second; +} + +} // namespace tensorrt +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h new file mode 100644 index 0000000000000000000000000000000000000000..5f8ad491d3c13e8911b0b95c3e95e19afe4d59c0 --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_resource_manager.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_RESOURCE_MANAGER_H_ +#include + +#include +#include +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { +namespace tensorrt { + +class TRTResourceManager { + TRTResourceManager() = default; + + public: + static std::shared_ptr instance() { + static std::shared_ptr instance_( + new TRTResourceManager); + return instance_; + } + // returns a manager for given op, if it doesn't exists it creates one + std::shared_ptr getManager(const string& op_name); + + private: + std::unordered_map> + managers_; + tensorflow::mutex map_mutex_; +}; + +} // namespace tensorrt +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCE_TRT_RESOURCE_MANAGER_H_ diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h new file mode 100644 index 0000000000000000000000000000000000000000..3c85968ae7acf5c5fc567be6805a5d226b1094c7 --- /dev/null +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -0,0 +1,95 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_ +#define TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRTRESOURCES_H_ + +#include +#include +#include +#include +#include +#include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/core/framework/resource_mgr.h" + +#if GOOGLE_CUDA +#if GOOGLE_TENSORRT +#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h" +#include "tensorrt/include/NvInfer.h" + +namespace tensorflow { +namespace tensorrt { +class TRTCalibrationResource : public tensorflow::ResourceBase { + public: + TRTCalibrationResource() + : calibrator_(nullptr), + builder_(nullptr), + network_(nullptr), + engine_(nullptr), + logger_(nullptr), + thr_(nullptr) {} + string DebugString() override { + std::stringstream oss; + oss << " Calibrator = " << std::hex << calibrator_ << std::dec << std::endl + << " Builder = " << std::hex << builder_ << std::dec << std::endl + << " Network = " << std::hex << network_ << std::dec << std::endl + << " Engine = " << std::hex << engine_ << std::dec << std::endl + << " Logger = " << std::hex << logger_ << std::dec << std::endl + << " Thread = " << std::hex << thr_ << std::dec << std::endl; + return oss.str(); + } + ~TRTCalibrationResource() { + VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); + } + TRTInt8Calibrator* calibrator_; + nvinfer1::IBuilder* builder_; + nvinfer1::INetworkDefinition* network_; + nvinfer1::ICudaEngine* engine_; + tensorflow::tensorrt::Logger* logger_; + // TODO(sami): Use threadpool threads! + std::thread* thr_; +}; + +class TRTWeightStore : public tensorflow::ResourceBase { + public: + TRTWeightStore() {} + std::list> store_; + string DebugString() override { + std::stringstream oss; + size_t lenBytes = 0; + for (const auto& v : store_) { + lenBytes += v.size() * sizeof(uint8_t); + } + oss << " Number of entries = " << store_.size() << std::endl + << " Total number of bytes = " + << store_.size() * sizeof(std::vector) + lenBytes << std::endl; + return oss.str(); + } + virtual ~TRTWeightStore() { VLOG(1) << "Destroying store" << DebugString(); } +}; + +class TRTEngineResource : public tensorflow::ResourceBase { + public: + TRTEngineResource() : runtime_(nullptr), ctx_(nullptr){}; + string DebugString() override { return string(""); } + nvinfer1::IRuntime* runtime_; + nvinfer1::IExecutionContext* ctx_; +}; + +} // namespace tensorrt +} // namespace tensorflow +#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCEMGR_TRTRESOURCES_H_ +#endif +#endif diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py index 9e4077eca0eafa5405bd217006a48cf946bd9bed..cfa18ab1874f74079acef1f7d670febca325730d 100644 --- a/tensorflow/contrib/tensorrt/test/test_tftrt.py +++ b/tensorflow/contrib/tensorrt/test/test_tftrt.py @@ -37,7 +37,7 @@ from tensorflow.python.ops import nn_ops as nn_ops def get_simple_graph_def(): - """Create a simple graph and return its graph_def""" + """Create a simple graph and return its graph_def.""" g = ops.Graph() with g.as_default(): a = aops.placeholder( diff --git a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv index b49a0662c29b1d810f4be31ca1f318f0571f533e..9b15b4f0b26f11ac3281ca4206654872984628b6 100644 --- a/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv +++ b/tensorflow/contrib/timeseries/examples/data/multivariate_periods.csv @@ -1,100 +1,100 @@ -0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0. -1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0. -2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0. -3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0. -4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0. -5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0. -6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0. -7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0. -8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0. -9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0. -10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0. -11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0. -12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0. -13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0. -14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0. -15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0. -16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0. -17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0. -18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0. -19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0. -20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0. -21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0. -22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0. -23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0. -24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0. -25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0. -26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0. -27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0. -28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0. -29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0. -30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0. -31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0. -32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0. -33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0. -34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0. -35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0. -36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0. -37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0. -38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0. -39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0. -40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0. -41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0. -42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0. -43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0. -44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0. -45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0. -46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0. -47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0. -48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0. -49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0. -50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0. -51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0. -52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0. -53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0. -54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0. -55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0. -56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0. -57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0. -58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0. -59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0. -60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0. -61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0. -62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0. -63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0. -64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0. -65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0. -66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0. -67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0. -68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0. -69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0. -70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0. -71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0. -72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0. -73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0. -74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0. -75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0. -76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0. -77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0. -78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0. -79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0. -80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0. -81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0. -82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0. -83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0. -84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0. -85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0. -86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0. -87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0. -88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0. -89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0. -90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0. -91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0. -92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0. -93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0. -94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0. -95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0. -96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0. -97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0. -98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0. -99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0. +0,0.926906299771,1.99107237682,2.56546245685,3.07914768197,4.04839057867,1.,0.,strkeya +1,0.108010001864,1.41645361423,2.1686839775,2.94963962176,4.1263503303,1.,0.,strkeyb +2,-0.800567600028,1.0172132907,1.96434754116,2.99885333086,4.04300485864,1.,0.,strkey +3,0.0607042871898,0.719540073421,1.9765012584,2.89265588817,4.0951014426,1.,0.,strkey +4,0.933712200629,0.28052120776,1.41018552514,2.69232603996,4.06481164223,1.,0.,strkey +5,-0.171730652974,0.260054421028,1.48770816369,2.62199129293,4.44572807842,1.,0.,strkey +6,-1.00180162933,0.333045158863,1.50006392277,2.88888309683,4.24755865606,1.,0.,strkey +7,0.0580061875336,0.688929398826,1.56543458772,2.99840358953,4.52726873347,1.,0.,strkey +8,0.764139447412,1.24704875327,1.77649279698,3.13578593851,4.63238922951,1.,0.,strkey +9,-0.230331874785,1.47903998963,2.03547545751,3.20624030377,4.77980005228,1.,0.,strkey +10,-1.03846045211,2.01133000781,2.31977503972,3.67951536251,5.09716775897,1.,0.,strkeyc +11,0.188643592253,2.23285349038,2.68338482249,3.49817168611,5.24928239634,1.,0.,strkey +12,0.91207302309,2.24244446841,2.71362604985,3.96332587625,5.37802271594,1.,0.,strkey +13,-0.296588665881,2.02594634141,3.07733910479,3.99698324956,5.56365901394,1.,0.,strkey +14,-0.959961476551,1.45078629833,3.18996420137,4.3763059609,5.65356015609,1.,0.,strkey +15,0.46313530679,1.01141441548,3.4980215948,4.20224896882,5.88842247449,1.,0.,strkey +16,0.929354125798,0.626635305936,3.70508262244,4.51791573544,5.73945973251,1.,0.,strkey +17,-0.519110731957,0.269249223148,3.39866823332,4.46802003061,5.82768174382,1.,0.,strkey +18,-0.924330981367,0.349602834684,3.21762413294,4.72803587499,5.94918925767,1.,0.,strkey +19,0.253239387885,0.345158023497,3.11071425333,4.79311566935,5.9489259713,1.,0.,strkey +20,0.637408390225,0.698996675371,3.25232492145,4.73814732384,5.9612010251,1.,0.,strkey +21,-0.407396859412,1.17456342803,2.49526823723,4.59323415742,5.82501686811,1.,0.,strkey +22,-0.967485452118,1.66655933642,2.47284606244,4.58316034754,5.88721406681,1.,0.,strkey +23,0.474480867904,1.95018556323,2.0228950072,4.48651142819,5.8255943735,1.,0.,strkey +24,1.04309652155,2.23519892356,1.91924131572,4.19094661783,5.87457348436,1.,0.,strkey +25,-0.517861513772,2.12501967336,1.70266619979,4.05280882887,5.72160912899,1.,0.,strkey +26,-0.945301585146,1.65464653549,1.81567174251,3.92309850635,5.58270493814,1.,0.,strkey +27,0.501153868974,1.40600764889,1.53991387719,3.72853247942,5.60169001727,1.,0.,strkey +28,0.972859524418,1.00344321868,1.5175642828,3.64092376655,5.10567722582,1.,0.,strkey +29,-0.70553406135,0.465306263885,1.7038540803,3.33236870312,5.09182481555,1.,0.,strkey +30,-0.946093634916,0.294539309453,1.88052827037,2.93011492669,4.97354922696,1.,0.,strkey +31,0.47922123231,0.308465865031,2.03445883031,2.90772899045,4.86241793548,1.,0.,strkey +32,0.754030014252,0.549752241167,2.46115815089,2.95063349534,4.71834614627,1.,0.,strkey +33,-0.64875949826,0.894615488148,2.5922463381,2.81269864022,4.43480095104,1.,0.,strkey +34,-0.757829951086,1.39123914261,2.69258079904,2.61834837315,4.36580046156,1.,0.,strkey +35,0.565653301088,1.72360022693,2.97794913834,2.80403840334,4.27327248459,1.,0.,strkey +36,0.867440092372,2.21100730052,3.38648090792,2.84057515729,4.12210169576,1.,0.,strkey +37,-0.894567758095,2.17549105818,3.45532493329,2.90446025717,4.00251740584,1.,0.,strkeyd +38,-0.715442356893,2.15105389965,3.52041791902,3.03650393392,4.12809249577,1.,0.,strkey +39,0.80671703672,1.81504564517,3.60463324866,3.00747789871,3.98440762467,1.,0.,strkey +40,0.527014790142,1.31803513865,3.43842186337,3.3332594663,4.03232406566,1.,0.,strkey +41,-0.795936862129,0.847809114454,3.09875133548,3.52863155938,3.94883924909,1.,0.,strkey +42,-0.610245806946,0.425530441018,2.92581949152,3.77238736123,4.27287245021,1.,0.,strkey +43,0.611662279431,0.178432049837,2.48128214822,3.73212087883,4.17319013831,1.,0.,strkey +44,0.650866553108,0.220341648392,2.41694642022,4.2609098519,4.27271645905,1.,0.,strkey +45,-0.774156982023,0.632667602331,2.05474356052,4.32889204886,4.18029723271,1.,0.,strkey +46,-0.714058448409,0.924562377599,1.75706135146,4.52492718422,4.3972678094,1.,0.,strkey +47,0.889627293379,1.46207968841,1.78299357672,4.64466731095,4.56317887554,1.,0.,strkey +48,0.520140662861,1.8996333843,1.41377633823,4.48899091177,4.78805049769,1.,0.,strkey +49,-1.03816935616,2.08997002059,1.51218375351,4.84167764204,4.93026048606,1.,0.,strkey +50,-0.40772951362,2.30878972136,1.44144415128,4.76854460997,5.01538444629,1.,0.,strkey +51,0.792730684781,1.91367048509,1.58887384677,4.71739397335,5.25690012199,1.,0.,strkey +52,0.371311881576,1.67565079528,1.81688563053,4.60353107555,5.44265822961,1.,0.,strkey +53,-0.814398070371,1.13374634126,1.80328814859,4.72264252878,5.52674761122,1.,0.,strkey +54,-0.469017949323,0.601244136627,2.29690896736,4.49859178859,5.54126153454,1.,0.,strkey +55,0.871044371426,0.407597593794,2.7499112487,4.19060637761,5.57693767301,1.,0.,strkey +56,0.523764933017,0.247705192709,3.09002071379,4.02095509006,5.80510362182,1.,0.,strkey +57,-0.881326403531,0.31513103164,3.11358205718,3.96079100808,5.81000652365,1.,0.,strkey +58,-0.357928025339,0.486163915865,3.17884556771,3.72634990659,5.85693642011,1.,0.,strkey +59,0.853038779822,1.04218094475,3.45835384454,3.36703969978,5.9585988449,1.,0.,strkey +60,0.435311516013,1.59715085283,3.63313338588,3.11276729421,5.93643818229,1.,0.,strkey +61,-1.02703719138,1.92205832542,3.47606111735,3.06247155999,6.02106646259,1.,0.,strkey +62,-0.246661325557,2.14653802542,3.29446326567,2.89936259181,5.67531541272,1.,0.,strkey +63,1.02554736569,2.25943737733,3.07031591528,2.78176218013,5.78206328989,1.,0.,strkey +64,0.337814475969,2.07589147224,2.80356226089,2.55888206331,5.7094075496,1.,0.,strkey +65,-1.12023369929,1.25333011618,2.56497288445,2.77361359194,5.50799418376,1.,0.,strkey +66,-0.178980246554,1.11937139901,2.51598681313,2.91438309151,5.47469577206,1.,0.,strkey +67,0.97550951531,0.60553823137,2.11657741073,2.88081098981,5.37034999502,1.,0.,strkey +68,0.136653357206,0.365828836075,1.97386033165,3.13217903204,5.07254490219,1.,0.,strkey +69,-1.05607596951,0.153152115069,1.52110743825,3.01308794192,5.08902539125,1.,0.,strkey +70,-0.13095280331,0.337113974483,1.52703079853,3.16687131599,4.86649398514,1.,0.,strkey +71,1.07081057754,0.714247566736,1.53761382634,3.45151989484,4.75892309166,1.,0.,strkey +72,0.0153410376082,1.24631231847,1.61690939161,3.85481994498,4.35683752832,1.,0.,strkey +73,-0.912801257303,1.60791309476,1.8729264524,4.03037260012,4.36072588913,1.,0.,strkey +74,-0.0894895640338,2.02535207407,1.93484909619,4.09557485132,4.35327025188,1.,0.,strkey +75,0.978646999652,2.20085086625,2.09003440427,4.27542353033,4.1805058388,1.,0.,strkey +76,-0.113312642876,2.2444100761,2.50789248839,4.4151861502,4.03267168136,1.,0.,strkey +77,-1.00215099149,1.84305628445,2.61691237246,4.45425147595,3.81203553766,1.,0.,strkey +78,-0.0183234614205,1.49573923116,2.99308471214,4.71134960112,4.0273804959,1.,0.,strkey +79,1.0823738177,1.12211589848,3.27079386925,4.94288270502,4.01851068083,1.,0.,strkey +80,0.124370187893,0.616474412808,3.4284236674,4.76942168327,3.9749536483,1.,0.,strkey +81,-0.929423379352,0.290977090976,3.34131726136,4.78590392707,4.10190661656,1.,0.,strkey +82,0.23766302648,0.155302052254,3.49779513794,4.64605656795,4.15571321107,1.,0.,strkey +83,1.03531486192,0.359702776204,3.4880725919,4.48167586667,4.21134561991,1.,0.,strkey +84,-0.261234571382,0.713877760378,3.42756426614,4.426443869,4.25208300527,1.,0.,strkey +85,-1.03572442277,1.25001113691,2.96908341113,4.25500915322,4.25723010649,1.,0.,strkey +86,0.380034261243,1.70543355622,2.73605932518,4.16703432307,4.63700400788,1.,0.,strkey +87,1.03734873488,1.97544410562,2.55586572141,3.84976673263,4.55282864289,1.,0.,strkey +88,-0.177344253372,2.22614526325,2.09565864891,3.77378097953,4.82577400298,1.,0.,strkey +89,-0.976821526892,2.18385079177,1.78522284118,3.67768223554,5.06302440873,1.,0.,strkey +90,0.264820472091,1.86981946157,1.50048403865,3.43619796921,5.05651761669,1.,0.,strkey +91,1.05642344868,1.47568646076,1.51347671977,3.20898518885,5.50149047462,1.,0.,strkey +92,-0.311607433358,1.04226467636,1.52089650905,3.02291865417,5.4889046232,1.,0.,strkey +93,-0.724285777937,0.553052311957,1.48573560173,2.7365973598,5.72549174225,1.,0.,strkey +94,0.519859192905,0.226520626591,1.61543723167,2.84102086852,5.69330622288,1.,0.,strkey +95,1.0323195039,0.260873217055,1.81913034804,2.83951143848,5.90325028086,1.,0.,strkey +96,-0.53285682538,0.387695521405,1.70935609313,2.57977050631,5.79579213161,1.,0.,strkey +97,-0.975127997215,0.920948771589,2.51292643636,2.71004616612,5.87016469227,1.,0.,strkey +98,0.540246804099,1.36445470181,2.61949412896,2.98482553485,6.02447664937,1.,0.,strkey +99,0.987764008058,1.85581989607,2.84685706149,2.94760204892,6.0212151724,1.,0.,strkey diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py index 7659dd308a7ee1b70d6688b85e4f6157ddee0540..c08c0b0acb917f527d7efa91874d6405b9220083 100644 --- a/tensorflow/contrib/timeseries/examples/known_anomaly.py +++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py @@ -46,12 +46,12 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300): # Indicate the format of our exogenous feature, in this case a string # representing a boolean value. - string_feature = tf.contrib.layers.sparse_column_with_keys( - column_name="is_changepoint", keys=["no", "yes"]) + string_feature = tf.feature_column.categorical_column_with_vocabulary_list( + key="is_changepoint", vocabulary_list=["no", "yes"]) # Specify the way this feature is presented to the model, here using a one-hot # encoding. - one_hot_feature = tf.contrib.layers.one_hot_column( - sparse_id_column=string_feature) + one_hot_feature = tf.feature_column.indicator_column( + categorical_column=string_feature) estimator = tf.contrib.timeseries.StructuralEnsembleRegressor( periodicities=12, diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py index f37cafcc502dc9415db0829b9b067b862f87dca7..2eee878196bb64b523c491ca808ca8d6ff5dd36c 100644 --- a/tensorflow/contrib/timeseries/examples/lstm.py +++ b/tensorflow/contrib/timeseries/examples/lstm.py @@ -59,10 +59,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): num_units: The number of units in the model's LSTMCell. num_features: The dimensionality of the time series (features per timestep). - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects representing features which are inputs to the model but are - not predicted by it. These must then be present for training, - evaluation, and prediction. + exogenous_feature_columns: A list of `tf.feature_column`s representing + features which are inputs to the model but are not predicted by + it. These must then be present for training, evaluation, and + prediction. dtype: The floating point data type to use. """ super(_LSTMModel, self).__init__( @@ -189,12 +189,16 @@ def train_and_predict( export_directory=None): """Train and predict using a custom time series model.""" # Construct an Estimator from our LSTM model. + categorical_column = tf.feature_column.categorical_column_with_hash_bucket( + key="categorical_exogenous_feature", hash_bucket_size=16) exogenous_feature_columns = [ # Exogenous features are not part of the loss, but can inform # predictions. In this example the features have no extra information, but # are included as an API example. - tf.contrib.layers.real_valued_column( - "2d_exogenous_feature", dimension=2)] + tf.feature_column.numeric_column( + "2d_exogenous_feature", shape=(2,)), + tf.feature_column.embedding_column( + categorical_column=categorical_column, dimension=10)] estimator = ts_estimators.TimeSeriesRegressor( model=_LSTMModel(num_features=5, num_units=128, exogenous_feature_columns=exogenous_feature_columns), @@ -205,7 +209,11 @@ def train_and_predict( csv_file_name, column_names=((tf.contrib.timeseries.TrainEvalFeatures.TIMES,) + (tf.contrib.timeseries.TrainEvalFeatures.VALUES,) * 5 - + ("2d_exogenous_feature",) * 2)) + + ("2d_exogenous_feature",) * 2 + + ("categorical_exogenous_feature",)), + # Data types other than for `times` need to be specified if they aren't + # float32. In this case one of our exogenous features has string dtype. + column_dtypes=((tf.int64,) + (tf.float32,) * 7 + (tf.string,))) train_input_fn = tf.contrib.timeseries.RandomWindowInputFn( reader, batch_size=4, window_size=32) estimator.train(input_fn=train_input_fn, steps=training_steps) @@ -215,7 +223,9 @@ def train_and_predict( predict_exogenous_features = { "2d_exogenous_feature": numpy.concatenate( [numpy.ones([1, 100, 1]), numpy.zeros([1, 100, 1])], - axis=-1)} + axis=-1), + "categorical_exogenous_feature": numpy.array( + ["strkey"] * 100)[None, :, None]} (predictions,) = tuple(estimator.predict( input_fn=tf.contrib.timeseries.predict_continuation_input_fn( evaluation, steps=100, diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index fff972c1f3277ad5d83673a202a50d1e6f7df210..ed3ed4c0e1731df62e9197aa7471fd6a31e9858e 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -140,11 +140,13 @@ py_library( "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:state_ops", + "//tensorflow/python:summary", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/estimator:export", "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:metric_keys", ], ) diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index f8355f366fe8e191ab570fd271bbe4a8bf71c73d..8d13343e82340dae11b0be54e3bc3152060dca36 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.layers.python.layers import feature_column - from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import feature_keys from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib @@ -31,10 +29,12 @@ from tensorflow.contrib.timeseries.python.timeseries.state_space_models.filterin from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.export import export_lib +from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.training import training as train @@ -117,22 +117,29 @@ class TimeSeriesRegressor(estimator_lib.Estimator): dtype=self._model.dtype), shape=(default_batch_size, default_series_length, self._model.num_features))) - with ops.Graph().as_default(): - # Default placeholders have only an unknown batch dimension. Make them - # in a separate graph, then splice in the series length to the shapes - # and re-create them in the outer graph. - exogenous_feature_shapes = { - key: (value.get_shape(), value.dtype) for key, value - in feature_column.make_place_holder_tensors_for_base_features( - self._model.exogenous_feature_columns).items()} - for feature_key, (batch_only_feature_shape, value_dtype) in ( - exogenous_feature_shapes.items()): - batch_only_feature_shape = batch_only_feature_shape.with_rank_at_least( - 1).as_list() - feature_shape = ([default_batch_size, default_series_length] - + batch_only_feature_shape[1:]) - placeholders[feature_key] = array_ops.placeholder( - dtype=value_dtype, name=feature_key, shape=feature_shape) + if self._model.exogenous_feature_columns: + with ops.Graph().as_default(): + # Default placeholders have only an unknown batch dimension. Make them + # in a separate graph, then splice in the series length to the shapes + # and re-create them in the outer graph. + parsed_features = ( + feature_column.make_parse_example_spec( + self._model.exogenous_feature_columns)) + placeholder_features = parsing_ops.parse_example( + serialized=array_ops.placeholder( + shape=[None], dtype=dtypes.string), + features=parsed_features) + exogenous_feature_shapes = { + key: (value.get_shape(), value.dtype) for key, value + in placeholder_features.items()} + for feature_key, (batch_only_feature_shape, value_dtype) in ( + exogenous_feature_shapes.items()): + batch_only_feature_shape = ( + batch_only_feature_shape.with_rank_at_least(1).as_list()) + feature_shape = ([default_batch_size, default_series_length] + + batch_only_feature_shape[1:]) + placeholders[feature_key] = array_ops.placeholder( + dtype=value_dtype, name=feature_key, shape=feature_shape) # Models may not know the shape of their state without creating some # variables/ops. Avoid polluting the default graph by making a new one. We # use only static metadata from the returned Tensors. @@ -333,11 +340,11 @@ class StructuralEnsembleRegressor(StateSpaceRegressor): determine the model size. Learning autoregressive coefficients typically requires more steps and a smaller step size than other components. - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects (for example tf.contrib.layers.embedding_column) corresponding - to exogenous features which provide extra information to the model but - are not part of the series to be predicted. Passed to - tf.contrib.layers.input_from_feature_columns. + exogenous_feature_columns: A list of `tf.feature_column`s (for example + `tf.feature_column.embedding_column`) corresponding to exogenous + features which provide extra information to the model but are not part + of the series to be predicted. Passed to + `tf.feature_column.input_layer`. exogenous_update_condition: A function taking two Tensor arguments, `times` (shape [batch size]) and `features` (a dictionary mapping exogenous feature keys to Tensors with shapes [batch size, ...]), and diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py index f0330bfbbd6e8067e5d085376acdf2e6bcaccb6a..5c49e903abde6d7487d1ffdb83ff902ff6b63585 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/head.py +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -26,6 +26,7 @@ from tensorflow.contrib.timeseries.python.timeseries import feature_keys from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.export import export_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -35,6 +36,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest +from tensorflow.python.summary import summary def time_series_regression_head(model, @@ -71,12 +73,31 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc self.input_statistics_generator = input_statistics_generator self._name = name + @property + def name(self): + return self._name + + # TODO(terrytangyuan): consolidate `model_outputs` and `_Head.LossSpec` + # once `_Head.create_loss` becomes extendable + def create_loss(self, features, mode, logits=None, labels=None): + """See `_Head`.""" + model_outputs = self.state_manager.define_loss( + self.model, features, mode) + summary.scalar( + head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS), + model_outputs.loss) + return model_outputs + + @property + def logits_dimension(self): + """See `_Head`.""" + return 1 + def _train_ops(self, features): """Add training ops to the graph.""" + mode = estimator_lib.ModeKeys.TRAIN with variable_scope.variable_scope("model"): - model_outputs = self.state_manager.define_loss( - self.model, features, estimator_lib.ModeKeys.TRAIN) - + model_outputs = self.create_loss(features, mode) train_op = optimizers.optimize_loss( model_outputs.loss, global_step=training_util.get_global_step(), @@ -85,31 +106,14 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc learning_rate=None) return estimator_lib.EstimatorSpec( loss=model_outputs.loss, - mode=estimator_lib.ModeKeys.TRAIN, + mode=mode, train_op=train_op) - # TODO(terrytangyuan): suffix summary and metrics keys by `"/" + name` - @property - def name(self): - return self._name - - # TODO(terrytangyuan): unused for now. Need to decouple - # `state_manager.define_loss` to satisfy the extendable return signature of - # `_Head.create_loss`. - def create_loss(self, features, mode, logits, labels): - """See `_Head`.""" - return None - - # TODO(terrytangyuan): check label dimension - @property - def logits_dimension(self): - return None - def _evaluate_ops(self, features): """Add ops for evaluation (aka filtering) to the graph.""" + mode = estimator_lib.ModeKeys.EVAL with variable_scope.variable_scope("model"): - model_outputs = self.state_manager.define_loss( - self.model, features, estimator_lib.ModeKeys.EVAL) + model_outputs = self.create_loss(features, mode) metrics = {} # Just output in-sample predictions for the last chunk seen for prediction_key, prediction_value in model_outputs.predictions.items(): @@ -122,7 +126,7 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc model_outputs.end_state)) return estimator_lib.EstimatorSpec( loss=model_outputs.loss, - mode=estimator_lib.ModeKeys.EVAL, + mode=mode, eval_metric_ops=metrics, predictions={}) @@ -140,9 +144,8 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc with variable_scope.variable_scope("model"): prediction_outputs = self.model.predict(features=features) with variable_scope.variable_scope("model", reuse=True): - filtering_outputs = self.state_manager.define_loss( - self.model, features, estimator_lib.ModeKeys.EVAL) - + filtering_outputs = self.create_loss( + features, estimator_lib.ModeKeys.EVAL) return estimator_lib.EstimatorSpec( mode=estimator_lib.ModeKeys.PREDICT, export_outputs={ @@ -191,7 +194,7 @@ class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-acc def create_estimator_spec(self, features, mode, labels=None): """Performs basic error checking and returns an EstimatorSpec.""" - with ops.name_scope("head"): + with ops.name_scope(self._name, "head"): if labels: raise ValueError( "The model received a `labels` dictionary, which is " diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py index d4ee59036624cffb216709e096981d362670e416..04225333b9377447f46d32663df76aece97a51e7 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py +++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline.py @@ -500,6 +500,41 @@ class CSVReader(ReaderBaseTimeSeriesParser): return features +class TFExampleReader(ReaderBaseTimeSeriesParser): + """Reads and parses `tf.Example`s from a TFRecords file.""" + + def __init__(self, + filenames, + features): + """Configure `tf.Example` parsing. + + Args: + filenames: A filename or list of filenames to read the time series + from. Each line must have columns corresponding to `column_names`. + features: A dictionary mapping from feature keys to `tf.FixedLenFeature` + objects. Must include `TrainEvalFeatures.TIMES` (scalar integer) and + `TrainEvalFeatures.VALUES` (floating point vector) features. + Raises: + ValueError: If required times/values features are not present. + """ + if feature_keys.TrainEvalFeatures.TIMES not in features: + raise ValueError("'{}' is a required column.".format( + feature_keys.TrainEvalFeatures.TIMES)) + if feature_keys.TrainEvalFeatures.VALUES not in features: + raise ValueError("'{}' is a required column.".format( + feature_keys.TrainEvalFeatures.VALUES)) + self._features = features + super(TFExampleReader, self).__init__(filenames=filenames) + + def _get_reader(self): + return io_ops.TFRecordReader() + + def _process_records(self, examples): + """Parse `tf.Example`s into `Tensors`.""" + return parsing_ops.parse_example( + serialized=examples, features=self._features) + + class TimeSeriesInputFn(object): """Base for classes which create batches of windows from a time series.""" diff --git a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py index ed78a835a4d451e9e7d18bb833d8ebed6c05a195..703537abf0fe3985aaf0434cc633cb410dd6bd4c 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/input_pipeline_test.py @@ -27,7 +27,11 @@ from tensorflow.contrib.timeseries.python.timeseries import input_pipeline from tensorflow.contrib.timeseries.python.timeseries import test_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures +from tensorflow.core.example import example_pb2 +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.lib.io import tf_record +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.training import coordinator as coordinator_lib @@ -52,6 +56,21 @@ def _make_csv_time_series(num_features, num_samples, test_tmpdir): return filename +def _make_tfexample_series(num_features, num_samples, test_tmpdir): + _, data_file = tempfile.mkstemp(dir=test_tmpdir) + with tf_record.TFRecordWriter(data_file) as writer: + for i in range(num_samples): + example = example_pb2.Example() + times = example.features.feature[TrainEvalFeatures.TIMES] + times.int64_list.value.append(i) + values = example.features.feature[TrainEvalFeatures.VALUES] + values.float_list.value.extend( + [float(i) * 2. + feature_number + for feature_number in range(num_features)]) + writer.write(example.SerializeToString()) + return data_file + + def _make_numpy_time_series(num_features, num_samples): times = numpy.arange(num_samples) values = times[:, None] * 2. + numpy.arange(num_features)[None, :] @@ -107,6 +126,19 @@ class RandomWindowInputFnTests(test.TestCase): time_series_reader = input_pipeline.CSVReader([filename]) self._test_out_of_order(time_series_reader, discard_out_of_order=False) + def test_tfexample_sort_out_of_order(self): + filename = _make_tfexample_series( + num_features=1, num_samples=50, + test_tmpdir=self.get_temp_dir()) + time_series_reader = input_pipeline.TFExampleReader( + [filename], + features={ + TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature( + shape=[], dtype=dtypes.int64), + TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature( + shape=[1], dtype=dtypes.float32)}) + self._test_out_of_order(time_series_reader, discard_out_of_order=False) + def test_numpy_sort_out_of_order(self): data = _make_numpy_time_series(num_features=1, num_samples=50) time_series_reader = input_pipeline.NumpyReader(data) @@ -183,6 +215,20 @@ class RandomWindowInputFnTests(test.TestCase): self._test_multivariate(time_series_reader=time_series_reader, num_features=2) + def test_tfexample_multivariate(self): + filename = _make_tfexample_series( + num_features=2, num_samples=50, + test_tmpdir=self.get_temp_dir()) + time_series_reader = input_pipeline.TFExampleReader( + [filename], + features={ + TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature( + shape=[], dtype=dtypes.int64), + TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature( + shape=[2], dtype=dtypes.float32)}) + self._test_multivariate(time_series_reader=time_series_reader, + num_features=2) + def test_numpy_multivariate(self): data = _make_numpy_time_series(num_features=3, num_samples=50) time_series_reader = input_pipeline.NumpyReader(data) @@ -248,6 +294,20 @@ class WholeDatasetInputFnTests(test.TestCase): self._whole_dataset_input_fn_test_template( time_series_reader=time_series_reader, num_features=1, num_samples=50) + def test_tfexample(self): + filename = _make_tfexample_series( + num_features=4, num_samples=100, + test_tmpdir=self.get_temp_dir()) + time_series_reader = input_pipeline.TFExampleReader( + [filename], + features={ + TrainEvalFeatures.TIMES: parsing_ops.FixedLenFeature( + shape=[], dtype=dtypes.int64), + TrainEvalFeatures.VALUES: parsing_ops.FixedLenFeature( + shape=[4], dtype=dtypes.float32)}) + self._whole_dataset_input_fn_test_template( + time_series_reader=time_series_reader, num_features=4, num_samples=100) + def test_numpy(self): data = _make_numpy_time_series(num_features=4, num_samples=100) time_series_reader = input_pipeline.NumpyReader(data) diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index bac7d1ebf59b28d4688a3d1a69ecdc1fc12248e0..7644764a7459db3951fe9a2790389713dd412a8f 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -21,18 +21,17 @@ from __future__ import print_function import abc import collections -from tensorflow.contrib import layers -from tensorflow.contrib.layers import feature_column - from tensorflow.contrib.timeseries.python.timeseries import math_utils from tensorflow.contrib.timeseries.python.timeseries.feature_keys import PredictionFeatures from tensorflow.contrib.timeseries.python.timeseries.feature_keys import TrainEvalFeatures +from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope @@ -66,11 +65,11 @@ class TimeSeriesModel(object): Args: num_features: Number of features for the time series - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects (for example tf.contrib.layers.embedding_column) corresponding - to exogenous features which provide extra information to the model but - are not part of the series to be predicted. Passed to - tf.contrib.layers.input_from_feature_columns. + exogenous_feature_columns: A list of `tf.feature_column`s (for example + `tf.feature_column.embedding_column`) corresponding to exogenous + features which provide extra information to the model but are not + part of the series to be predicted. Passed to + `tf.feature_column.input_layer`. dtype: The floating point datatype to use. """ if exogenous_feature_columns: @@ -86,7 +85,7 @@ class TimeSeriesModel(object): @property def exogenous_feature_columns(self): - """`FeatureColumn` objects for features which are not predicted.""" + """`tf.feature_colum`s for features which are not predicted.""" return self._exogenous_feature_columns # TODO(allenl): Move more of the generic machinery for generating and @@ -265,11 +264,14 @@ class TimeSeriesModel(object): if not self._exogenous_feature_columns: return (0,) with ops.Graph().as_default(): - placeholder_features = ( - feature_column.make_place_holder_tensors_for_base_features( + parsed_features = ( + feature_column.make_parse_example_spec( self._exogenous_feature_columns)) - embedded = layers.input_from_feature_columns( - columns_to_tensors=placeholder_features, + placeholder_features = parsing_ops.parse_example( + serialized=array_ops.placeholder(shape=[None], dtype=dtypes.string), + features=parsed_features) + embedded = feature_column.input_layer( + features=placeholder_features, feature_columns=self._exogenous_feature_columns) return embedded.get_shape().as_list()[1:] @@ -308,13 +310,13 @@ class TimeSeriesModel(object): # Avoid shape warnings when embedding "scalar" exogenous features (those # with only batch and window dimensions); input_from_feature_columns # expects input ranks to match the embedded rank. - if tensor.get_shape().ndims == 1: + if tensor.get_shape().ndims == 1 and tensor.dtype != dtypes.string: exogenous_features_single_batch_dimension[name] = tensor[:, None] else: exogenous_features_single_batch_dimension[name] = tensor embedded_exogenous_features_single_batch_dimension = ( - layers.input_from_feature_columns( - columns_to_tensors=exogenous_features_single_batch_dimension, + feature_column.input_layer( + features=exogenous_features_single_batch_dimension, feature_columns=self._exogenous_feature_columns, trainable=True)) exogenous_regressors = array_ops.reshape( @@ -381,8 +383,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel): may use _scale_back_data or _scale_back_variance to return predictions to the input scale. dtype: The floating point datatype to use. - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects. See `TimeSeriesModel`. + exogenous_feature_columns: A list of `tf.feature_column`s objects. See + `TimeSeriesModel`. exogenous_update_condition: A function taking two Tensor arguments `times` (shape [batch size]) and `features` (a dictionary mapping exogenous feature keys to Tensors with shapes [batch size, ...]) and returning a diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py index 6257002647ed53bbde3ace11a6b45e4e2cdeb57d..951c6546d5fed77e0cfa98a4e774b804639d7dad 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py @@ -112,11 +112,11 @@ class StateSpaceModelConfiguration( exogenous_noise_decreases: If True, exogenous regressors can "set" model state, decreasing uncertainty. If both this parameter and exogenous_noise_increases are False, exogenous regressors are ignored. - exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn - objects (for example tf.contrib.layers.embedding_column) corresponding - to exogenous features which provide extra information to the model but - are not part of the series to be predicted. Passed to - tf.contrib.layers.input_from_feature_columns. + exogenous_feature_columns: A list of `tf.feature_column`s (for example + `tf.feature_column.embedding_column`) corresponding to exogenous + features which provide extra information to the model but are not part + of the series to be predicted. Passed to + `tf.feature_column.input_layer`. exogenous_update_condition: A function taking two Tensor arguments `times` (shape [batch size]) and `features` (a dictionary mapping exogenous feature keys to Tensors with shapes [batch size, ...]) and returning a diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index c48e84ddfaac8ac9c07e061847315eab3fd72152..095b4821f10b32ff742711caa155e60beb624852 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -163,6 +163,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":datasets", ":profiler", ":tpu_py", "//tensorflow/contrib/tpu/proto:topology_proto_py", @@ -181,6 +182,33 @@ py_library( ], ) +py_library( + name = "datasets", + srcs = [ + "python/tpu/datasets.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", + ], +) + +tf_py_test( + name = "datasets_test", + srcs = ["python/tpu/datasets_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + ":datasets", + ], + grpc_enabled = True, +) + tf_py_test( name = "tpu_test", size = "small", diff --git a/tensorflow/contrib/tpu/ops/infeed_ops.cc b/tensorflow/contrib/tpu/ops/infeed_ops.cc index 849c4a1102787870b372c35740cf0fe271efa078..efc546f9a6077de9cac5a5acefa3fc7206547fc6 100644 --- a/tensorflow/contrib/tpu/ops/infeed_ops.cc +++ b/tensorflow/contrib/tpu/ops/infeed_ops.cc @@ -41,6 +41,7 @@ REGISTER_OP("InfeedEnqueue") .Attr("dtype: type") .Attr("shape: shape = {}") .Attr("device_ordinal: int = -1") + .SetShapeFn(shape_inference::NoOutputs) .SetIsStateful() .Doc(R"doc( An op which feeds a single Tensor value into the computation. @@ -58,6 +59,7 @@ REGISTER_OP("InfeedEnqueueTuple") .Attr("dtypes: list(type)") .Attr("shapes: list(shape)") .Attr("device_ordinal: int = -1") + .SetShapeFn(shape_inference::NoOutputs) .SetIsStateful() .Doc(R"doc( An op which feeds multiple Tensor values into the computation as an XLA tuple. diff --git a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py index 78d237e6a201541b6095b101311db48b447cc477..a730d6142d890cc41f72176cf617ac0b0434192c 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/cloud_tpu_profiler/main.py @@ -25,19 +25,34 @@ import sys import tensorflow as tf +# Cloud TPU Cluster Resolvers flags.DEFINE_string( - 'service_addr', None, 'Address of TPU profiler service e.g. ' - 'localhost:8466') + 'gcp_project', None, + 'Project name for the Cloud TPU-enabled project. If not specified, we ' + 'will attempt to automatically detect the GCE project from metadata.') +flags.DEFINE_string( + 'tpu_zone', + None, + help='GCE zone where the Cloud TPU is located in. If not specified, we ' + 'will attempt to automatically detect the GCE project from metadata.') +flags.DEFINE_string('tpu_name', None, + 'Name of the Cloud TPU for Cluster Resolvers. You must ' + 'specify either this flag or --master.') + +# Tool specific parameters flags.DEFINE_string( - 'logdir', None, 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' - 'gs://tb_bucket') + 'service_addr', None, 'Address of TPU profiler service e.g. ' + 'localhost:8466, you must specify either this flag or --tpu_name.') +flags.DEFINE_string('logdir', None, + 'Path of TensorBoard log directory e.g. /tmp/tb_log, ' + 'gs://tb_bucket') flags.DEFINE_integer('duration_ms', 2000, 'Duration of tracing in ms.') -flags.DEFINE_integer( - 'num_tracing_attempts', 3, 'Automatically retry N times when no trace ' - 'event is collected.') -flags.DEFINE_boolean( - 'include_dataset_ops', True, 'Set to false to profile longer TPU ' - 'device traces.') +flags.DEFINE_integer('num_tracing_attempts', 3, + 'Automatically retry N times when no trace ' + 'event is collected.') +flags.DEFINE_boolean('include_dataset_ops', True, + 'Set to false to profile longer TPU ' + 'device traces.') FLAGS = flags.FLAGS EXECUTABLE = 'data/capture_tpu_profile' @@ -48,16 +63,35 @@ def run_main(): def main(unused_argv=None): - if not FLAGS.service_addr or not FLAGS.logdir: - sys.exit('service_addr and logdir must be provided.') + tf.logging.set_verbosity(tf.logging.INFO) + + if FLAGS.service_addr is None and FLAGS.tpu_name is None: + sys.exit('You must specify either --service_addr or --tpu_name.') + + if FLAGS.service_addr is not None: + if FLAGS.tpu_name is not None: + tf.logging.warn('Both --service_addr and --tpu_name are set. Ignoring ' + '--tpu_name and using --service_addr.') + service_addr = FLAGS.service_addr + else: + tpu_cluster_resolver = ( + tf.contrib.cluster_resolver.TPUClusterResolver( + tpu_names=[FLAGS.tpu_name], + zone=FLAGS.tpu_zone, + project=FLAGS.gcp_project)) + service_addr = tpu_cluster_resolver.get_master() + service_addr = service_addr.replace('grpc://', '').replace(':8470', ':8466') + + if not FLAGS.logdir: + sys.exit('logdir must be provided.') executable_path = os.path.join(os.path.dirname(__file__), EXECUTABLE) logdir = os.path.expandvars(os.path.expanduser(FLAGS.logdir)) cmd = [executable_path] - cmd.append('--logdir='+logdir) - cmd.append('--service_addr='+FLAGS.service_addr) - cmd.append('--duration_ms='+str(FLAGS.duration_ms)) - cmd.append('--num_tracing_attempts='+str(FLAGS.num_tracing_attempts)) - cmd.append('--include_dataset_ops='+str(FLAGS.include_dataset_ops).lower()) + cmd.append('--logdir=' + logdir) + cmd.append('--service_addr=' + service_addr) + cmd.append('--duration_ms=' + str(FLAGS.duration_ms)) + cmd.append('--num_tracing_attempts=' + str(FLAGS.num_tracing_attempts)) + cmd.append('--include_dataset_ops=' + str(FLAGS.include_dataset_ops).lower()) subprocess.call(cmd) diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py index 76f1dd2a567b570be6d1e127d1382773bf94493d..8d99835b64152629c66607e6792495eb36319eb8 100644 --- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py +++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py @@ -20,7 +20,7 @@ from __future__ import print_function from setuptools import setup -_VERSION = '1.6.0-rc0' +_VERSION = '1.6.0-rc1' CONSOLE_SCRIPTS = [ 'capture_tpu_profile=cloud_tpu_profiler.main:run_main', diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 97876216793e0e6b20b7c072cac4f575b8fd48be..14c63a79763300dcfe8d6c8e09b90f8e9c772358 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -47,7 +47,7 @@ if platform.system() != "Windows": # types are supported. _SUPPORTED_INFEED_DTYPES = set([ - dtypes.bool, dtypes.int32, dtypes.bfloat16, dtypes.float32, + dtypes.bool, dtypes.int32, dtypes.int64, dtypes.bfloat16, dtypes.float32, dtypes.complex64 ]) diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..29aea98542e31ee82249ecb2e2100c8a974a4fb7 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/datasets.py @@ -0,0 +1,192 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ====================================== +"""Library of Cloud TPU helper functions for data loading.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import functional_ops + + +def _TextLineDataset(filename): + buffer_size = 8 * 1024 * 1024 # 8 MiB per file + dataset = readers.TextLineDataset(filename, buffer_size=buffer_size) + return dataset + + +def _TFRecordDataset(filename): + buffer_size = 8 * 1024 * 1024 # 8 MiB per file + dataset = readers.TFRecordDataset(filename, buffer_size=buffer_size) + return dataset + + +_FILETYPE_MAP = { + 'tfrecord': _TFRecordDataset, + 'textline': _TextLineDataset, + 'text': _TextLineDataset, +} + + +def StreamingFilesDataset(files, + filetype=None, + file_reader_job=None, + worker_job=None, + num_epochs=None, + filename_shuffle_buffer_size=None, + num_parallel_reads=None, + batch_transfer_size=None, + sloppy=None): + """StreamingFilesDataset constructs a dataset to stream from workers (GCE VM). + + Because Cloud TPUs are allocated over the network, a Cloud TPU cannot read + files local to your GCE VM. In order to train using files stored on your local + VM (e.g. on local SSD for extreme performance), use the StreamingFilesDataset + helper to generate a dataset to feed your Cloud TPU with files from your GCE + VM. + + The resulting dataset may return an OutOfRangeError if there are no files + found as a result of the fileglob expansion. + + Note: StreamingFilesDataset assumes that the session is using a + TPUClusterResolver and has therefore a worker and a coordinator job. File + loading will be done on the coordinator job. + + Args: + files: A string glob to match files, or a `tf.data.Dataset` generating file + names. + filetype: A string (one of 'tfrecord', or 'textline') or a single-argument + TensorFlow function that when given a filename returns a dataset. + file_reader_job: An optional string that corresponds to the job that should + perform the file reads. + worker_job: An optional string that corresponds to the job that should + process the tensors (i.e. your GPU or TPU worker). + num_epochs: The number of epochs through the training set that should be + generated. By default, it will repeat infinitely. + filename_shuffle_buffer_size: An optional integer whose value controls the + shuffling of the file names. If you would like to read from the files in + the same order, set to 0 or False. + num_parallel_reads: An optional integer controlling the number of files to + read from concurrently. (Set to 1 for no parallelism.) + batch_transfer_size: An optional integer controlling the batching used to + amortize the remote function invocation overhead. Set to a very large + number to increase throughput. Set to a very small number to reduce memory + consumption. Set to False to skip batching. + sloppy: (Optional.) If `True`, read input data as fast as possible, without + maintaining a deterministic order. Defaults to `False`. + Returns: + A `tf.data.Dataset` with an infinite stream of elements generated by a + parallel interleaving of the set of files matched (or generated) by `files` + with a type is the output of the dataset specified by `filetype`. + + Raises: + ValueError: if any argument is not of the expected type. + """ + if filetype is None: + filetype = 'tfrecord' + + if isinstance(filetype, str): + if filetype not in _FILETYPE_MAP: + raise ValueError('Unexpected filetype: %s' % filetype) + reader_fn = _FILETYPE_MAP[filetype] + elif callable(filetype): + reader_fn = filetype + else: + raise ValueError('filetype should be a string or a callable') + + file_reader_job = file_reader_job or 'coordinator' + + worker_job = worker_job or 'worker' + + if filename_shuffle_buffer_size is None: + filename_shuffle_buffer_size = 4096 + + num_parallel_reads = num_parallel_reads or 8 + + if batch_transfer_size is None: + batch_transfer_size = 1024 + + if sloppy is None: + sloppy = False + + with ops.device('/job:%s' % file_reader_job): + if isinstance(files, str): + source_dataset = dataset_ops.Dataset.list_files(files) + elif isinstance(files, dataset_ops.Dataset): + source_dataset = files + else: + raise ValueError('files was not a string or a dataset: %s' % files) + + if filename_shuffle_buffer_size: + source_dataset = source_dataset.shuffle( + buffer_size=filename_shuffle_buffer_size) + + # NOTE: We perform the `repeat` on the source dataset, because the output + # dataset does not currently have enough information to recreate an iterator + # over the source dataset when it reaches the end. + source_dataset = source_dataset.repeat(num_epochs) + + source_dataset = source_dataset.apply( + interleave_ops.parallel_interleave( + reader_fn, cycle_length=num_parallel_reads, sloppy=sloppy)) + + if batch_transfer_size: + # Note: we can safely call batch_and_drop_remainder because we have an + # infinite stream of TFRecords. + source_dataset = source_dataset.apply( + batching.batch_and_drop_remainder(batch_transfer_size)) + + source_dataset = source_dataset.prefetch(1) + + source_iterator = source_dataset.make_one_shot_iterator() + source_handle = source_iterator.string_handle() + + @function.Defun(dtypes.string) + def LoadingFunc(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, source_dataset.output_types, source_dataset.output_shapes) + return remote_iterator.get_next() + + def MapFn(unused_input): + return functional_ops.remote_call( + args=[source_handle], + Tout=[dtypes.string], + f=LoadingFunc, + target='/job:%s/replica:0/task:0/cpu:0' % file_reader_job) + + with ops.device('/job:%s' % worker_job): + # TODO(saeta,mrry): Switch to using _GeneratorDataset. + + # identity = lambda x: x + # dummy = constant_op.constant(0) + # output_dataset = dataset_ops._GeneratorDataset(dummy, identity, MapFn, + # identity) + + output_dataset = dataset_ops.Dataset.range(2).repeat().map(MapFn) + output_dataset = output_dataset.prefetch(1) + + if batch_transfer_size: + # Undo the batching used during the transfer. + output_dataset = output_dataset.apply(batching.unbatch()).prefetch(1) + + return output_dataset diff --git a/tensorflow/contrib/tpu/python/tpu/datasets_test.py b/tensorflow/contrib/tpu/python/tpu/datasets_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2c4079779215328c671cbb7fefd356c926fad4f4 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/datasets_test.py @@ -0,0 +1,181 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TPU datasets tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.tpu.python.tpu import datasets +from tensorflow.core.protobuf import cluster_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.lib.io import python_io +from tensorflow.python.platform import test +from tensorflow.python.training import server_lib +from tensorflow.python.util import compat + +_NUM_FILES = 10 +_NUM_ENTRIES = 200 + + +class DatasetsTest(test.TestCase): + + def setUp(self): + super(DatasetsTest, self).setUp() + self._coord = server_lib.Server.create_local_server() + self._worker = server_lib.Server.create_local_server() + + self._cluster_def = cluster_pb2.ClusterDef() + worker_job = self._cluster_def.job.add() + worker_job.name = 'worker' + worker_job.tasks[0] = self._worker.target[len('grpc://'):] + coord_job = self._cluster_def.job.add() + coord_job.name = 'coordinator' + coord_job.tasks[0] = self._coord.target[len('grpc://'):] + + session_config = config_pb2.ConfigProto(cluster_def=self._cluster_def) + + self._sess = session.Session(self._worker.target, config=session_config) + + def testTextLineDataset(self): + all_contents = [] + for i in range(_NUM_FILES): + filename = os.path.join(self.get_temp_dir(), 'text_line.%d.txt' % i) + contents = [] + for j in range(_NUM_ENTRIES): + contents.append(compat.as_bytes('%d: %d' % (i, j))) + with open(filename, 'wb') as f: + f.write(b'\n'.join(contents)) + all_contents.extend(contents) + + dataset = datasets.StreamingFilesDataset( + os.path.join(self.get_temp_dir(), 'text_line.*.txt'), filetype='text') + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = [] + for _ in range(2 * len(all_contents)): + retrieved_values.append(compat.as_bytes(self._sess.run(get_next))) + + self.assertEqual(set(all_contents), set(retrieved_values)) + + def testTFRecordDataset(self): + all_contents = [] + for i in range(_NUM_FILES): + filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i) + writer = python_io.TFRecordWriter(filename) + for j in range(_NUM_ENTRIES): + record = compat.as_bytes('Record %d of file %d' % (j, i)) + writer.write(record) + all_contents.append(record) + writer.close() + + dataset = datasets.StreamingFilesDataset( + os.path.join(self.get_temp_dir(), 'tf_record*'), filetype='tfrecord') + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = [] + for _ in range(2 * len(all_contents)): + retrieved_values.append(compat.as_bytes(self._sess.run(get_next))) + + self.assertEqual(set(all_contents), set(retrieved_values)) + + def testTFRecordDatasetFromDataset(self): + filenames = [] + all_contents = [] + for i in range(_NUM_FILES): + filename = os.path.join(self.get_temp_dir(), 'tf_record.%d' % i) + filenames.append(filename) + writer = python_io.TFRecordWriter(filename) + for j in range(_NUM_ENTRIES): + record = compat.as_bytes('Record %d of file %d' % (j, i)) + writer.write(record) + all_contents.append(record) + writer.close() + + filenames = dataset_ops.Dataset.from_tensor_slices(filenames) + + dataset = datasets.StreamingFilesDataset(filenames, filetype='tfrecord') + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = [] + for _ in range(2 * len(all_contents)): + retrieved_values.append(compat.as_bytes(self._sess.run(get_next))) + + self.assertEqual(set(all_contents), set(retrieved_values)) + + def testArbitraryReaderFunc(self): + + def MakeRecord(i, j): + return compat.as_bytes('%04d-%04d' % (i, j)) + + record_bytes = len(MakeRecord(10, 200)) + + all_contents = [] + for i in range(_NUM_FILES): + filename = os.path.join(self.get_temp_dir(), 'fixed_length.%d' % i) + with open(filename, 'wb') as f: + for j in range(_NUM_ENTRIES): + record = MakeRecord(i, j) + f.write(record) + all_contents.append(record) + + def FixedLengthFile(filename): + return readers.FixedLengthRecordDataset(filename, record_bytes) + + dataset = datasets.StreamingFilesDataset( + os.path.join(self.get_temp_dir(), 'fixed_length*'), + filetype=FixedLengthFile) + + iterator = dataset.make_initializable_iterator() + self._sess.run(iterator.initializer) + get_next = iterator.get_next() + + retrieved_values = [] + for _ in range(2 * len(all_contents)): + retrieved_values.append(compat.as_bytes(self._sess.run(get_next))) + + self.assertEqual(set(all_contents), set(retrieved_values)) + + def testUnexpectedFiletypeString(self): + with self.assertRaises(ValueError): + datasets.StreamingFilesDataset( + os.path.join(self.get_temp_dir(), '*'), filetype='foo') + + def testUnexpectedFiletypeType(self): + with self.assertRaises(ValueError): + datasets.StreamingFilesDataset( + os.path.join(self.get_temp_dir(), '*'), filetype=3) + + def testUnexpectedFilesType(self): + with self.assertRaises(ValueError): + datasets.StreamingFilesDataset(123, filetype='tfrecord') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 644070218214643923b9ca3ee138615ec568e8b5..7ceb4069cf011d88b6fb4586d7e80acbacf9aebe 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -26,6 +26,7 @@ import os import numpy as np from tensorflow.contrib.tpu.python.tpu import util as util_lib +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import run_config as run_config_lib from tensorflow.python.platform import tf_logging as logging @@ -140,6 +141,7 @@ class RunConfig(run_config_lib.RunConfig): tpu_config=None, evaluation_master=None, master=None, + cluster=None, **kwargs): """Constructs a RunConfig. @@ -148,15 +150,26 @@ class RunConfig(run_config_lib.RunConfig): evaluation_master: a string. The address of the master to use for eval. Defaults to master if not set. master: a string. The address of the master to use for training. + cluster: a ClusterResolver **kwargs: keyword config parameters. + + Raises: + ValueError: if cluster is not None and the provided session_config has a + cluster_def already. """ super(RunConfig, self).__init__(**kwargs) self._tpu_config = tpu_config or TPUConfig() + self._cluster = cluster # If user sets master and/or evaluation_master explicilty, including empty # string '', take it. Otherwise, take the values set by parent class. if master is not None: + if cluster is not None: + raise ValueError('Both master and cluster are set.') self._master = master + else: + if cluster: + self._master = cluster.master() if evaluation_master is not None: self._evaluation_master = evaluation_master @@ -170,6 +183,20 @@ class RunConfig(run_config_lib.RunConfig): # evaluation_master to master, unless user overwrites it. self._evaluation_master = self._master + # Set the ClusterSpec to use + if cluster: + self._cluster_spec = cluster.cluster_spec() + + # Merge the cluster_def into the ConfigProto. + if self._session_config is None: # pylint: disable=access-member-before-definition + self._session_config = config_pb2.ConfigProto(allow_soft_placement=True) + if self._session_config.HasField('cluster_def'): + raise ValueError( + 'You cannot provide a ClusterResolver and ' + 'session_config.cluster_def.') + self._session_config.cluster_def.CopyFrom( + self._cluster_spec.as_cluster_def()) + @property def evaluation_master(self): return self._evaluation_master @@ -182,6 +209,10 @@ class RunConfig(run_config_lib.RunConfig): def tpu_config(self): return self._tpu_config + @property + def cluster(self): + return self._cluster + def replace(self, **kwargs): if 'tpu_config' not in kwargs: return super(RunConfig, self).replace(**kwargs) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py index 8c65018d1490728476f1e902eb9ca619b0fe9188..c5c46ea741ea64ca37089431f8ed66cad7bc31fb 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py @@ -106,7 +106,9 @@ class _TPUContext(object): # pylint: disable=protected-access tpu_system_metadata = ( tpu_system_metadata_lib._query_tpu_system_metadata( - master, query_topology=self.model_parallelism_enabled)) + master, + run_config=self._config, + query_topology=self.model_parallelism_enabled)) self._lazy_tpu_system_metadata_dict[master] = tpu_system_metadata return tpu_system_metadata @@ -409,6 +411,22 @@ class _TPUContext(object): 'Tensorflow master address and TPU worker(s). Available devices ' 'are {}.'.format(tpu_system_metadata.devices)) + if self._config.tpu_config.num_shards: + user_provided_num_replicas = self._config.tpu_config.num_shards + if user_provided_num_replicas != num_replicas: + message = ( + 'TPUConfig.num_shards is not set correctly. According to TPU ' + 'system metadata for Tensorflow master ({}): num_replicas should ' + 'be ({}), got ({}). For non-model-parallelism, num_replicas should ' + 'be the total num of TPU cores in the system. For ' + 'model-parallelism, the total number of TPU cores should be ' + 'product(computation_shape) * num_replicas. Please set it ' + 'accordingly or leave it as `None`'.format( + self._get_master_address(), num_replicas, + user_provided_num_replicas)) + + raise ValueError(message) + if mode == model_fn_lib.ModeKeys.TRAIN: if self._train_batch_size % num_replicas != 0: raise ValueError( diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index ff53fe4f5d0e219f56d77d3476640bb023c7535a..1b2eda1caa0fa2779834d65b5a49121d9cc0af56 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -1763,6 +1763,9 @@ class TPUEstimator(estimator_lib.Estimator): if 'config' in input_fn_args: kwargs['config'] = config + if 'mode' in input_fn_args: + kwargs['mode'] = mode + with self._ctx.with_mode(mode) as ctx: # Setting the batch size in params first. This helps user to have same # input_fn for use_tpu=True/False. diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py index e003313667dcacbd6d10951ab45e988048f8f50e..493d1848c072caa5254fc87c67badc2e99ec16ee 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_system_metadata.py @@ -45,7 +45,8 @@ _TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [ ]) -def _query_tpu_system_metadata(master_address, query_topology=False): +def _query_tpu_system_metadata(master_address, run_config, + query_topology=False): """Automatically detects the TPU system metadata in the system.""" tpu_core_count = 0 devices = [] @@ -59,8 +60,8 @@ def _query_tpu_system_metadata(master_address, query_topology=False): with ops.Graph().as_default(): with session_lib.Session( master_address, - config=config_pb2.ConfigProto( - operation_timeout_in_ms=_PINGING_MASTER_TIMEOUT_IN_MS)) as sess: + config=_get_session_config_with_timeout( + _PINGING_MASTER_TIMEOUT_IN_MS, run_config)) as sess: devices = sess.list_devices() for device in devices: match = _TPU_DEVICE_REG.match(device.name) @@ -104,7 +105,7 @@ def _query_tpu_system_metadata(master_address, query_topology=False): 'TPU worker has some problems. Available devices: {}'.format( master_address, devices)) - topology = _obtain_topology(master_address) + topology = _obtain_topology(master_address, run_config) metadata = _TPUSystemMetadata( num_cores=tpu_core_count, @@ -113,19 +114,26 @@ def _query_tpu_system_metadata(master_address, query_topology=False): topology=topology, devices=devices) - msg = 'Found TPU system %s' if tpu_core_count else 'Failed to find TPU: %s' - logging.info(msg, metadata) + if tpu_core_count: + logging.info('Found TPU system:') + logging.info('*** Num TPU Cores: %d', metadata.num_cores) + logging.info('*** Num TPU Workers: %d', metadata.num_hosts) + logging.info('*** Num TPU Cores Per Worker: %d', + metadata.num_of_cores_per_host) + logging.info('*** Available Devices: %s', metadata.devices) + else: + logging.info('Failed to find TPU: %s', metadata) return metadata -def _obtain_topology(master_address): +def _obtain_topology(master_address, run_config): try: logging.info('Initializing TPU system (master: %s) to fetch topology ' 'for model parallelism. This might take a while.', master_address) with ops.Graph().as_default(): - session_config = config_pb2.ConfigProto( - operation_timeout_in_ms=_INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS) + session_config = _get_session_config_with_timeout( + _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS, run_config) with session_lib.Session( master_address, config=session_config) as sess: topology = sess.run(tpu.initialize_system()) @@ -137,3 +145,11 @@ def _obtain_topology(master_address): master_address)) +def _get_session_config_with_timeout(timeout_in_secs, run_config): + cluster_def = None + if run_config.session_config and run_config.session_config.cluster_def.job: + cluster_def = run_config.session_config.cluster_def + + config = config_pb2.ConfigProto( + operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def) + return config diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index fdfd27d6a414933b0bec824bae512c45dac24d3c..95e051e3b5bb9f8075e66891a45c64a27bca68d1 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -358,6 +358,8 @@ class HParams(object): ``` """ + _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks. + def __init__(self, hparam_def=None, model_structure=None, **kwargs): """Create an instance of `HParams` from keyword arguments. diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index d1fb9f444514fee4ce339d4308da0d583ae36aa0..1893967cdd0034bcff52c84f4db0bf1e2e3334f4 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -480,6 +480,7 @@ tf_cuda_library( "framework/type_index.h", "framework/type_traits.h", "framework/types.h", + "framework/visitable_allocator.h", "public/version.h", "util/activation_mode.h", "util/bcast.h", @@ -988,22 +989,15 @@ filegroup( # Core sources for Android builds. filegroup( - name = "mobile_srcs", + name = "mobile_srcs_no_runtime", srcs = [ ":proto_text_srcs_all", - "//tensorflow/core/kernels:android_srcs", "//tensorflow/core/platform/default/build_config:android_srcs", - "//tensorflow/core/util/ctc:android_srcs", - "//tensorflow/core/util/tensor_bundle:android_srcs", ] + glob( [ "client/**/*.cc", - "common_runtime/**/*.h", - "common_runtime/**/*.cc", "framework/**/*.h", "framework/**/*.cc", - "graph/**/*.h", - "graph/**/*.cc", "lib/**/*.h", "lib/**/*.cc", "platform/**/*.h", @@ -1019,7 +1013,6 @@ filegroup( "**/*main.cc", "debug/**/*", "framework/op_gen_*", - "graph/dot.*", "lib/jpeg/**/*", "lib/png/**/*", "lib/gif/**/*", @@ -1036,13 +1029,52 @@ filegroup( "platform/stream_executor.*", "platform/windows/**/*", "user_ops/**/*.cu.cc", + "util/ctc/*.h", + "util/ctc/*.cc", + "util/tensor_bundle/*.h", + "util/tensor_bundle/*.cc", + "common_runtime/gpu/**/*", + "common_runtime/gpu_device_factory.*", + ], + ), + visibility = ["//visibility:public"], +) + +filegroup( + name = "mobile_srcs_only_runtime", + srcs = [ + "//tensorflow/core/kernels:android_srcs", + "//tensorflow/core/util/ctc:android_srcs", + "//tensorflow/core/util/tensor_bundle:android_srcs", + ] + glob( + [ + "common_runtime/**/*.h", + "common_runtime/**/*.cc", + "graph/**/*.h", + "graph/**/*.cc", + ], + exclude = [ + "**/*test.*", + "**/*testutil*", + "**/*testlib*", + "**/*main.cc", "common_runtime/gpu/**/*", "common_runtime/gpu_device_factory.*", + "graph/dot.*", ], ), visibility = ["//visibility:public"], ) +filegroup( + name = "mobile_srcs", + srcs = [ + ":mobile_srcs_no_runtime", + ":mobile_srcs_only_runtime", + ], + visibility = ["//visibility:public"], +) + # Native library support for Android applications. Does not contain # operators, use :android_tensorflow_lib if you want full operator # support. @@ -1335,6 +1367,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "framework/kernel_def_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "framework/kernel_def.proto", + visibility = ["//visibility:public"], +) + tf_pyclif_proto_library( name = "framework/node_def_pyclif", proto_lib = ":protos_all_cc", @@ -1774,6 +1813,7 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/tracking_allocator.h", # only needed for tests "framework/unique_tensor_references.h", "framework/variant.h", + "framework/visitable_allocator.h", "platform/variant_coding.h", "util/command_line_flags.h", "util/env_var.h", @@ -2069,7 +2109,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/stats_publisher_interface.h", "common_runtime/step_stats_collector.h", "common_runtime/threadpool_device.h", - "common_runtime/visitable_allocator.h", "graph/gradients.h", "graph/quantize_training.h", ] + if_mkl(["graph/mkl_graph_util.h"]) @@ -2272,6 +2311,8 @@ GPU_RUNTIME_HEADERS = [ "common_runtime/gpu/gpu_cudamalloc_allocator.h", "common_runtime/gpu/gpu_debug_allocator.h", "common_runtime/gpu/gpu_device.h", + "common_runtime/gpu/gpu_id.h", + "common_runtime/gpu/gpu_id_manager.h", "common_runtime/gpu/gpu_id_utils.h", "common_runtime/gpu/gpu_init.h", "common_runtime/gpu/gpu_managed_allocator.h", @@ -3475,6 +3516,7 @@ tf_cc_tests( "ops/parsing_ops_test.cc", "ops/random_ops_test.cc", "ops/set_ops_test.cc", + "ops/shape_function_test.cc", "ops/sparse_ops_test.cc", "ops/spectral_ops_test.cc", "ops/state_ops_test.cc", @@ -3633,6 +3675,18 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +alias( + name = "android_srcs_no_runtime", + actual = ":mobile_srcs_no_runtime", + visibility = ["//visibility:public"], +) + +alias( + name = "android_srcs_only_runtime", + actual = ":mobile_srcs_only_runtime", + visibility = ["//visibility:public"], +) + alias( name = "android_srcs", actual = ":mobile_srcs", diff --git a/tensorflow/core/api_def/base_api/api_def_ConsumeMutexLock.pbtxt b/tensorflow/core/api_def/base_api/api_def_ConsumeMutexLock.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..b9db8274dea5d904dbbc687927673e0c7f7fa649 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ConsumeMutexLock.pbtxt @@ -0,0 +1,19 @@ +op { + graph_op_name: "ConsumeMutexLock" + in_arg { + name: "mutex_lock" + description: < [1, 2, 4, 7, 8] +idx ==> [0, 0, 1, 2, 2, 2, 3, 4, 4] +count ==> [2, 1, 3, 1, 2] +``` + +For an `2-D` tensor `x` with `axis = 0`: + +``` +# tensor 'x' is [[1, 0, 0], +# [1, 0, 0], +# [2, 0, 0]] +y, idx, count = unique_with_counts(x, axis=0) +y ==> [[1, 0, 0], + [2, 0, 0]] +idx ==> [0, 0, 1] +count ==> [2, 1] +``` + +For an `2-D` tensor `x` with `axis = 1`: + +``` +# tensor 'x' is [[1, 0, 0], +# [1, 0, 0], +# [2, 0, 0]] +y, idx, count = unique_with_counts(x, axis=1) +y ==> [[1, 0], + [1, 0], + [2, 0]] +idx ==> [0, 1, 1] +count ==> [1, 2] +``` +END +} diff --git a/tensorflow/core/api_def/python_api/api_def_UniqueWithCounts.pbtxt b/tensorflow/core/api_def/python_api/api_def_UniqueWithCounts.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..71b35eaab5f4a251ebebf9ddb7baf2ecd0a12401 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_UniqueWithCounts.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "UniqueWithCounts" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_UniqueWithCountsV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_UniqueWithCountsV2.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..7876e55cf3e2c24e19507cefb01f9f61abd0a2bc --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_UniqueWithCountsV2.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "UniqueWithCountsV2" + visibility: HIDDEN +} diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h index b8e773503c7a2f8024e8a6f58247ad343a762f71..e34945dd48a1e54e4ae82dd7ea9959f39a97f2c2 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.h +++ b/tensorflow/core/common_runtime/bfc_allocator.h @@ -23,7 +23,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/allocator_retry.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/visitable_allocator.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index df6f4b88773fb1a72100d1c223276a06b857a908..ecbffcbf6c4030bde82f2abe0e7779bf9c5a9870 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1250,7 +1250,7 @@ Status DirectSession::GetOrCreateExecutors( item->device = device; Executor* executor; TF_RETURN_IF_ERROR( - NewLocalExecutor(params, partition_graph.release(), &executor)); + NewLocalExecutor(params, std::move(partition_graph), &executor)); item->executor.reset(executor); } diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 6998cbecee5695832185c71ec745525abc50c38e..b06b75d6585f01640374eb7ab9842bf441cf9411 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -332,8 +332,8 @@ class GraphView { class ExecutorImpl : public Executor { public: - ExecutorImpl(const LocalExecutorParams& p, const Graph* g) - : params_(p), graph_(g), gview_() { + ExecutorImpl(const LocalExecutorParams& p, std::unique_ptr g) + : params_(p), graph_(std::move(g)), gview_() { CHECK(p.create_kernel != nullptr); CHECK(p.delete_kernel != nullptr); } @@ -348,7 +348,6 @@ class ExecutorImpl : public Executor { for (auto fiter : frame_info_) { delete fiter.second; } - delete graph_; } Status Initialize(); @@ -412,7 +411,7 @@ class ExecutorImpl : public Executor { // Owned. LocalExecutorParams params_; - const Graph* graph_; + std::unique_ptr graph_; GraphView gview_; // A cached value of params_ @@ -605,11 +604,11 @@ void GetMaxPendingCounts(const Node* n, size_t* max_pending, } Status ExecutorImpl::Initialize() { - gview_.Initialize(graph_); + gview_.Initialize(graph_.get()); // Build the information about frames in this subgraph. ControlFlowInfo cf_info; - TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_, &cf_info)); + TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph_.get(), &cf_info)); // Cache this value so we make this virtual function call once, rather // that O(# steps * # nodes per step) times. @@ -676,9 +675,9 @@ Status ExecutorImpl::Initialize() { // Initialize PendingCounts only after item->pending_id is initialized for // all nodes. - InitializePending(graph_, cf_info); + InitializePending(graph_.get(), cf_info); - return gview_.SetAllocAttrs(graph_, params_.device); + return gview_.SetAllocAttrs(graph_.get(), params_.device); } Status GraphView::SetAllocAttrs(const Graph* g, const Device* device) { @@ -1415,7 +1414,7 @@ void ExecutorImpl::InitializePending(const Graph* graph, } void ExecutorState::RunAsync(Executor::DoneCallback done) { - const Graph* graph = impl_->graph_; + const Graph* graph = impl_->graph_.get(); TaggedNodeSeq ready; // Ask the device to fill in the device context map. @@ -2606,9 +2605,10 @@ void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { } // end namespace -Status NewLocalExecutor(const LocalExecutorParams& params, const Graph* graph, +Status NewLocalExecutor(const LocalExecutorParams& params, + std::unique_ptr graph, Executor** executor) { - ExecutorImpl* impl = new ExecutorImpl(params, graph); + ExecutorImpl* impl = new ExecutorImpl(params, std::move(graph)); const Status s = impl->Initialize(); if (s.ok()) { *executor = impl; diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index 3fd932da5b6c44833ba940351dad6cf373ffa05c..adf80a2417e2a86e874dd1d1068a1bbb611ff882 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -122,9 +122,8 @@ class Executor { // Creates an Executor that computes the given "graph". // -// If successful, returns the constructed executor in "*executor". The -// caller keeps the ownership of "device". The returned executor takes -// the ownership of "graph". Otherwise, returns an error status. +// If successful, returns the constructed executor in "*executor". Otherwise, +// returns an error status. // // "params" provides a set of context for the executor. We expect that // different context would provide different implementations. @@ -143,7 +142,8 @@ struct LocalExecutorParams { Executor::Args::NodeOutputsCallback node_outputs_cb; }; ::tensorflow::Status NewLocalExecutor(const LocalExecutorParams& params, - const Graph* graph, Executor** executor); + std::unique_ptr graph, + Executor** executor); // A class to help run multiple executors in parallel and wait until // all of them are complete. diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index d349d2bb1251d8b31d8e432ad8c357d6f0a81389..3e937ceb640554be3a2578decdb336d0e58c197f 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -42,11 +42,8 @@ limitations under the License. namespace tensorflow { // A few string constant used throughout this module. -// -// TODO(zhifengc): Dedup some of these constants into -// framework/function.h -static constexpr const char* const kArgOp = "_Arg"; -static constexpr const char* const kRetOp = "_Retval"; +static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp; +static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp; static constexpr const char* const kGradientOp = FunctionLibraryDefinition::kGradientOp; static constexpr const char* const kNodeLabel = "Func"; @@ -177,6 +174,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime { } Device* device() override { return device_; } + const DeviceMgr* device_mgr() const override { return device_mgr_; } Env* env() override { return env_; } int graph_def_version() override { return graph_def_version_; } @@ -631,7 +629,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) { }; Graph* graph = g.get(); Executor* exec; - TF_RETURN_IF_ERROR(NewLocalExecutor(params, g.release(), &exec)); + TF_RETURN_IF_ERROR(NewLocalExecutor(params, std::move(g), &exec)); { // Guard item since it is already inserted in items_. @@ -1580,9 +1578,6 @@ Status FunctionDefToBodyHelper( // Call BuildControlFlowInfo to validate that this function body has // well-formed control flow. - // NOTE(skyewm): this is usually done in Partition(), but we don't partition - // function bodies. This should be removed if function bodies ever go through - // the Partition() path. std::vector dummy; TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy)); diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 8b051462990fc3abc5b864f644274bb8b2211191..63ad0d231c28a5af144b61e967a73e8ecfe6049a 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -71,11 +71,11 @@ class FunctionTest : public ::testing::Test { arg_types_ = result.arg_types; ret_types_ = result.ret_types; - Graph* g = new Graph(OpRegistry::Global()); + std::unique_ptr g(new Graph(OpRegistry::Global())); GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = false; - TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g)); + TF_CHECK_OK(ConvertNodeDefsToGraph(opts, result.nodes, g.get())); const int version = g->versions().producer(); LocalExecutorParams params; @@ -89,7 +89,7 @@ class FunctionTest : public ::testing::Test { DeleteNonCachedKernel(kernel); }; Executor* exec; - TF_CHECK_OK(NewLocalExecutor(params, g, &exec)); + TF_CHECK_OK(NewLocalExecutor(params, std::move(g), &exec)); exec_.reset(exec); } diff --git a/tensorflow/core/common_runtime/function_testlib.cc b/tensorflow/core/common_runtime/function_testlib.cc index 87c2476b04af7300d7138d59b3261496eb38c482..87733ed2dbe931c6bb64fd065d2691072d4eced0 100644 --- a/tensorflow/core/common_runtime/function_testlib.cc +++ b/tensorflow/core/common_runtime/function_testlib.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function_testlib.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" @@ -39,7 +40,9 @@ class FindDeviceOpKernel : public OpKernel { REGISTER_KERNEL_BUILDER(Name("FindDeviceOp").Device(tensorflow::DEVICE_CPU), FindDeviceOpKernel); -REGISTER_OP("FindDeviceOp").Output("device_name: string"); +REGISTER_OP("FindDeviceOp") + .Output("device_name: string") + .SetShapeFn(shape_inference::UnknownShape); FunctionDef FindDevice() { return FDH::Define( diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h index 208697361d2dfc4f3b8290ea511d15c9bd86857b..0a586344ccf2228a23059d68e7aa2d7a8f4eadba 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h @@ -19,7 +19,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/visitable_allocator.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h index adce3a84368ced958002443721016778cb6df028..0db08dc9759c9306ebd99b4acf4967128ef04895 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h @@ -21,7 +21,7 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/visitable_allocator.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 15ff15fd5ab28605c4ab0904e62305edc3815adb..8357cc5a7201b3b590c6965648eed72116167459 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -1013,21 +1013,34 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options, GpuIdUtil::CheckValidTfGpuId(tf_gpu_id); CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id); int numa_node = dev_locality.numa_node(); - Bytes allocated_bytes = static_cast(memory_limit); gpu::StreamExecutor* se = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); const gpu::DeviceDescription& desc = se->GetDeviceDescription(); - LOG(INFO) << "Creating TensorFlow device (" << device_name << " with " - << (memory_limit >> 20) << " MB memory) -> physical GPU (" - << GetShortDeviceDescription(cuda_gpu_id, desc) << ")"; ProcessState* process_state = ProcessState::singleton(); + Allocator* gpu_allocator = process_state->GetGPUAllocator( + options.config.gpu_options(), tf_gpu_id, memory_limit); + if (gpu_allocator == nullptr) { + return errors::Internal("Failed to get memory allocator for TF GPU ", + tf_gpu_id.value(), " with ", memory_limit, + " bytes of memory."); + } + AllocatorStats stats; + gpu_allocator->GetStats(&stats); + // 'memory_limit' is the required memory size, but if the allocator with given + // tf_gpu_id was created before, we'll use it instead of creating a new one + // (as TF gpu device is a shared resource), in which case the actual memory + // limit represented by 'stats.bytes_limit' used by that allocator may be + // different (which should be an error). + // + // TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit. BaseGPUDevice* gpu_device = CreateGPUDevice( - options, device_name, allocated_bytes, dev_locality, tf_gpu_id, - GetShortDeviceDescription(cuda_gpu_id, desc), - process_state->GetGPUAllocator(options.config.gpu_options(), tf_gpu_id, - memory_limit), + options, device_name, static_cast(stats.bytes_limit), dev_locality, + tf_gpu_id, GetShortDeviceDescription(cuda_gpu_id, desc), gpu_allocator, process_state->GetCPUAllocator(numa_node)); + LOG(INFO) << "Created TensorFlow device (" << device_name << " with " + << (stats.bytes_limit >> 20) << " MB memory) -> physical GPU (" + << GetShortDeviceDescription(cuda_gpu_id, desc) << ")"; TF_RETURN_IF_ERROR(gpu_device->Init(options)); devices->push_back(gpu_device); diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h index c88daa8ff87589a3fc48f4c7693d073d6adf9a5a..d817c7dd1f3af5656e48c3b2a0420270a7938447 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.h +++ b/tensorflow/core/common_runtime/gpu/gpu_device.h @@ -68,7 +68,7 @@ class BaseGPUDevice : public LocalDevice { const TensorReferenceVector& tensor_refs) override; Status FillContextMap(const Graph* graph, - DeviceContextMap* device_context_map); + DeviceContextMap* device_context_map) override; void Compute(OpKernel* op_kernel, OpKernelContext* context) override; diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc index b56823204afe8ee52e0ea376b1a79d91d6932fa0..f3935f6ba26c49a9967d0848bfb6d965c73d2fab 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc @@ -18,42 +18,48 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_device.h" #include "tensorflow/core/common_runtime/gpu/gpu_init.h" +#include "tensorflow/core/common_runtime/gpu/process_state.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { -namespace { const char* kDeviceNamePrefix = "/job:localhost/replica:0/task:0"; -static SessionOptions MakeSessionOptions( - const string& visible_device_list = "", - double per_process_gpu_memory_fraction = 0, int gpu_device_count = 1, - const std::vector>& memory_limit_mb = {}) { - SessionOptions options; - ConfigProto* config = &options.config; - (*config->mutable_device_count())["GPU"] = gpu_device_count; - GPUOptions* gpu_options = config->mutable_gpu_options(); - gpu_options->set_visible_device_list(visible_device_list); - gpu_options->set_per_process_gpu_memory_fraction( - per_process_gpu_memory_fraction); - for (const auto& v : memory_limit_mb) { - auto virtual_devices = - gpu_options->mutable_experimental()->add_virtual_devices(); - for (float mb : v) { - virtual_devices->add_memory_limit_mb(mb); +class GPUDeviceTest : public ::testing::Test { + public: + void TearDown() { ProcessState::singleton()->TestOnlyReset(); } + + protected: + static SessionOptions MakeSessionOptions( + const string& visible_device_list = "", + double per_process_gpu_memory_fraction = 0, int gpu_device_count = 1, + const std::vector>& memory_limit_mb = {}) { + SessionOptions options; + ConfigProto* config = &options.config; + (*config->mutable_device_count())["GPU"] = gpu_device_count; + GPUOptions* gpu_options = config->mutable_gpu_options(); + gpu_options->set_visible_device_list(visible_device_list); + gpu_options->set_per_process_gpu_memory_fraction( + per_process_gpu_memory_fraction); + for (const auto& v : memory_limit_mb) { + auto virtual_devices = + gpu_options->mutable_experimental()->add_virtual_devices(); + for (float mb : v) { + virtual_devices->add_memory_limit_mb(mb); + } } + return options; } - return options; -} -static bool StartsWith(const string& lhs, const string& rhs) { - if (rhs.length() > lhs.length()) return false; - return lhs.substr(0, rhs.length()) == rhs; -} + static bool StartsWith(const string& lhs, const string& rhs) { + if (rhs.length() > lhs.length()) return false; + return lhs.substr(0, rhs.length()) == rhs; + } +}; -TEST(GPUDeviceTest, FailedToParseVisibleDeviceList) { +TEST_F(GPUDeviceTest, FailedToParseVisibleDeviceList) { SessionOptions opts = MakeSessionOptions("0,abc"); std::vector devices; Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( @@ -63,7 +69,7 @@ TEST(GPUDeviceTest, FailedToParseVisibleDeviceList) { << status; } -TEST(GPUDeviceTest, InvalidGpuId) { +TEST_F(GPUDeviceTest, InvalidGpuId) { SessionOptions opts = MakeSessionOptions("100"); std::vector devices; Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( @@ -74,7 +80,7 @@ TEST(GPUDeviceTest, InvalidGpuId) { << status; } -TEST(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) { +TEST_F(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) { SessionOptions opts = MakeSessionOptions("0,0"); std::vector devices; Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( @@ -85,7 +91,7 @@ TEST(GPUDeviceTest, DuplicateEntryInVisibleDeviceList) { << status; } -TEST(GPUDeviceTest, VirtualDeviceConfigConflictsWithMemoryFractionSettings) { +TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithMemoryFractionSettings) { SessionOptions opts = MakeSessionOptions("0", 0.1, 1, {{}}); std::vector devices; Status status = DeviceFactory::GetFactory("GPU")->CreateDevices( @@ -96,7 +102,7 @@ TEST(GPUDeviceTest, VirtualDeviceConfigConflictsWithMemoryFractionSettings) { << status; } -TEST(GPUDeviceTest, GpuDeviceCountTooSmall) { +TEST_F(GPUDeviceTest, GpuDeviceCountTooSmall) { // device_count is 0, but with one entry in visible_device_list and one // (empty) VirtualDevices messages. SessionOptions opts = MakeSessionOptions("0", 0, 0, {{}}); @@ -109,7 +115,7 @@ TEST(GPUDeviceTest, GpuDeviceCountTooSmall) { << status; } -TEST(GPUDeviceTest, NotEnoughGpuInVisibleDeviceList) { +TEST_F(GPUDeviceTest, NotEnoughGpuInVisibleDeviceList) { // Single entry in visible_device_list with two (empty) VirtualDevices // messages. SessionOptions opts = MakeSessionOptions("0", 0, 8, {{}, {}}); @@ -122,7 +128,7 @@ TEST(GPUDeviceTest, NotEnoughGpuInVisibleDeviceList) { << status; } -TEST(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) { +TEST_F(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) { // This test requires at least two visible GPU hardware. if (GPUMachineManager()->VisibleDeviceCount() < 2) return; // Three entries in visible_device_list with two (empty) VirtualDevices @@ -139,7 +145,7 @@ TEST(GPUDeviceTest, VirtualDeviceConfigConflictsWithVisibleDeviceList) { << status; } -TEST(GPUDeviceTest, EmptyVirtualDeviceConfig) { +TEST_F(GPUDeviceTest, EmptyVirtualDeviceConfig) { // It'll create single virtual device when the virtual device config is empty. SessionOptions opts = MakeSessionOptions("0"); std::vector devices; @@ -150,7 +156,7 @@ TEST(GPUDeviceTest, EmptyVirtualDeviceConfig) { for (auto d : devices) delete d; } -TEST(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) { +TEST_F(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) { // It'll create single virtual device for the gpu in question when // memory_limit_mb is unset. SessionOptions opts = MakeSessionOptions("0", 0, 1, {{}}); @@ -162,7 +168,7 @@ TEST(GPUDeviceTest, SingleVirtualDeviceWithNoMemoryLimit) { for (auto d : devices) delete d; } -TEST(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) { +TEST_F(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) { SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123}}); std::vector devices; TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices( @@ -172,7 +178,7 @@ TEST(GPUDeviceTest, SingleVirtualDeviceWithMemoryLimit) { for (auto d : devices) delete d; } -TEST(GPUDeviceTest, MultipleVirtualDevices) { +TEST_F(GPUDeviceTest, MultipleVirtualDevices) { SessionOptions opts = MakeSessionOptions("0", 0, 1, {{123, 456}}); std::vector devices; TF_CHECK_OK(DeviceFactory::GetFactory("GPU")->CreateDevices( @@ -195,7 +201,6 @@ TEST(GPUDeviceTest, MultipleVirtualDevices) { for (auto d : devices) delete d; } -} // namespace } // namespace tensorflow #endif diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc index 207afdca75642b14c1617c8abae4fd5e9916f020..7dfff3269cf91582adf783dcd15dd55d1c4e1451 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc @@ -18,7 +18,10 @@ limitations under the License. #include #include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" namespace tensorflow { @@ -27,8 +30,8 @@ namespace { class TfToCudaGpuIdMap { public: static TfToCudaGpuIdMap* singleton() { - static auto* manager = new TfToCudaGpuIdMap; - return manager; + static auto* id_map = new TfToCudaGpuIdMap; + return id_map; } void InsertOrDie(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id) @@ -47,18 +50,41 @@ class TfToCudaGpuIdMap { } } - int32 FindOrDie(TfGpuId tf_gpu_id) const LOCKS_EXCLUDED(mu_) { + CudaGpuId FindOrDie(TfGpuId tf_gpu_id) const LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); + return FindOrDieLocked(tf_gpu_id); + } + + bool Find(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) const + LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + if (id_map_.count(tf_gpu_id.value()) == 0) return false; + *cuda_gpu_id = FindOrDieLocked(tf_gpu_id); + return true; + } + + private: + TfToCudaGpuIdMap() = default; + + CudaGpuId FindOrDieLocked(TfGpuId tf_gpu_id) const + EXCLUSIVE_LOCKS_REQUIRED(mu_) { auto result = id_map_.find(tf_gpu_id.value()); CHECK(result != id_map_.end()) << "Could not find the mapping for TfGpuId: " << tf_gpu_id; - return result->second; + return CudaGpuId(result->second); + } + + void TestOnlyReset() LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + id_map_.clear(); } - private: using IdMapType = std::unordered_map; mutable mutex mu_; IdMapType id_map_ GUARDED_BY(mu_); + + friend class ::tensorflow::GpuIdManager; + TF_DISALLOW_COPY_AND_ASSIGN(TfToCudaGpuIdMap); }; } // namespace @@ -67,8 +93,20 @@ void GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, TfToCudaGpuIdMap::singleton()->InsertOrDie(tf_gpu_id, cuda_gpu_id); } +Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) { + if (TfToCudaGpuIdMap::singleton()->Find(tf_gpu_id, cuda_gpu_id)) { + return Status::OK(); + } + return errors::NotFound("TF GPU device with id ", tf_gpu_id.value(), + " was not registered"); +} + CudaGpuId GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id) { - return CudaGpuId(TfToCudaGpuIdMap::singleton()->FindOrDie(tf_gpu_id)); + return TfToCudaGpuIdMap::singleton()->FindOrDie(tf_gpu_id); +} + +void GpuIdManager::TestOnlyReset() { + TfToCudaGpuIdMap::singleton()->TestOnlyReset(); } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h index 33925d8c36f44a9d2c7abc8f2801f3f203bcb982..2b54cc184ca508b94e2a715642cdb13fe8a4c3e1 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h +++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h @@ -17,15 +17,25 @@ limitations under the License. #define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_ID_MANAGER_H_ #include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/lib/core/status.h" namespace tensorflow { -// Class that manages the translation between Tensorflow GPU ids and CUDA GPU -// ids. +// Class that maintains a map from TfGpuId to CudaGpuId, and manages the +// translation between them. class GpuIdManager { public: + // Adds a mapping from tf_gpu_id to cuda_gpu_id. static void InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id); + + // Gets the cuda_gpu_id associated with tf_gpu_id. Returns OK if found. + static Status TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id); + // Similar to the above version, but returns the result, and checks fail if + // no result is found. static CudaGpuId TfToCudaGpuId(TfGpuId tf_gpu_id); + + // Clears the map. Used in unit tests only. + static void TestOnlyReset(); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.h b/tensorflow/core/common_runtime/gpu/pool_allocator.h index 91ce830df8521e7fe8284dd3c52d1bbf667891cd..38d669ea07c91bc1a892ecf925b3141f2ca506dd 100644 --- a/tensorflow/core/common_runtime/gpu/pool_allocator.h +++ b/tensorflow/core/common_runtime/gpu/pool_allocator.h @@ -24,7 +24,7 @@ limitations under the License. #include #include #include -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/visitable_allocator.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc index 61013bd1acd254b6e927a8d41accaeda424d6ebc..866a03d04632c649fae278c4ab311e22ebf8dc31 100644 --- a/tensorflow/core/common_runtime/gpu/process_state.cc +++ b/tensorflow/core/common_runtime/gpu/process_state.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" @@ -318,4 +319,17 @@ void ProcessState::AddGPUAllocVisitor(int bus_id, AllocVisitor visitor) { #endif // GOOGLE_CUDA } +void ProcessState::TestOnlyReset() { + mutex_lock lock(mu_); + gpu_device_enabled_ = false; + gpu_visitors_.clear(); + mem_desc_map_.clear(); + gtl::STLDeleteElements(&cpu_allocators_); + gtl::STLDeleteElements(&gpu_allocators_); + gtl::STLDeleteElements(&cuda_host_allocators_); + gtl::STLDeleteElements(&cpu_al_); + gtl::STLDeleteElements(&gpu_al_); + gtl::STLDeleteElements(&cuda_al_); +} + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/process_state.h b/tensorflow/core/common_runtime/gpu/process_state.h index f6e234967306476542cec3038ea2e271cca2dc8c..bc2c4182d72334e26d387397e564dbf02cfa3ae4 100644 --- a/tensorflow/core/common_runtime/gpu/process_state.h +++ b/tensorflow/core/common_runtime/gpu/process_state.h @@ -114,6 +114,10 @@ class ProcessState { protected: ProcessState(); + // Helper method for unit tests to reset the ProcessState singleton by + // cleaning up everything. Never use in production. + virtual void TestOnlyReset(); + static ProcessState* instance_; bool gpu_device_enabled_; @@ -132,6 +136,8 @@ class ProcessState { std::vector cpu_al_ GUARDED_BY(mu_); std::vector gpu_al_ GUARDED_BY(mu_); std::vector cuda_al_ GUARDED_BY(mu_); + + friend class GPUDeviceTest; }; namespace internal { diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index a21304f7ef843706d564bd3f3a511324fd3189d6..f1082a60030fb3c289de35b4cab397c527f8afca 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -156,21 +156,21 @@ Status GraphRunner::Run(Graph* graph, FunctionLibraryRuntime* function_library, // should not be running expensive operators. auto runner = [](Executor::Args::Closure c) { c(); }; - // Take ownership and pass to NewLocalExecutor - Graph* g = graph_to_run.release(); - LocalExecutorParams params; // The ownership of the output tensors are bound to this device's lifetime. params.device = cpu_device_.get(); params.function_library = function_library; - params.create_kernel = [this, g](const NodeDef& ndef, OpKernel** kernel) { - return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef, - g->versions().producer(), kernel); + const int producer = graph_to_run->versions().producer(); + params.create_kernel = [this, producer](const NodeDef& ndef, + OpKernel** kernel) { + return CreateNonCachedKernel(cpu_device_.get(), nullptr, ndef, producer, + kernel); }; params.delete_kernel = [](OpKernel* kernel) { delete kernel; }; Executor* executor; - TF_RETURN_IF_ERROR(NewLocalExecutor(params, g, &executor)); + TF_RETURN_IF_ERROR( + NewLocalExecutor(params, std::move(graph_to_run), &executor)); std::unique_ptr executor_unref(executor); Executor::Args args; diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc index 420dfe338efb473e36eb02a757fa957d15ba64df..64d884947568381eb2e5f60ab181b3c8c709d53b 100644 --- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc +++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc @@ -39,6 +39,7 @@ limitations under the License. namespace tensorflow { namespace test { +// TODO(hongm): Convert `g` and `init` to using std::unique_ptr. Benchmark::Benchmark(const string& device, Graph* g, const SessionOptions* options, Graph* init, Rendezvous* rendez) { @@ -85,7 +86,8 @@ Benchmark::Benchmark(const string& device, Graph* g, if (init) { Executor* init_exec; - TF_CHECK_OK(NewLocalExecutor(params, init, &init_exec)); + TF_CHECK_OK( + NewLocalExecutor(params, std::unique_ptr(init), &init_exec)); Executor::Args args; args.rendezvous = rendez_; args.runner = runner; @@ -93,7 +95,7 @@ Benchmark::Benchmark(const string& device, Graph* g, delete init_exec; } - TF_CHECK_OK(NewLocalExecutor(params, g, &exec_)); + TF_CHECK_OK(NewLocalExecutor(params, std::unique_ptr(g), &exec_)); } Benchmark::~Benchmark() { diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 0eb47f4e56dafff93736bd8c7112098fd11c0fed..fb092424bfc0b1bd8653e630b246b2749eb665fd 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -21,11 +21,10 @@ limitations under the License. #ifdef INTEL_MKL -#include #include #include #include "tensorflow/core/common_runtime/bfc_allocator.h" -#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/visitable_allocator.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mem.h" @@ -161,7 +160,7 @@ class MklCPUAllocator : public VisitableAllocator { /// The alignment that we need for the allocations static const size_t kAlignment = 64; - Allocator* allocator_; // owned by this class + VisitableAllocator* allocator_; // owned by this class }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index a913f2075181a3896015579d79093395d67101ff..e128b9257f2369e25c911f9a9e1d08475706d561 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -464,6 +464,7 @@ class ColocationGraph { // the user can see why an unsatisfiable placement occurred. std::unordered_map type_to_devices; + std::vector colocation_nodes; int num_nodes_found = 0; for (const Node* node : graph_->nodes()) { @@ -475,6 +476,7 @@ class ColocationGraph { continue; } ++num_nodes_found; + colocation_nodes.push_back(node); const string& op_type = node->type_string(); string devices_registered; for (const auto& device_type : members_[id].supported_device_types) { @@ -488,6 +490,13 @@ class ColocationGraph { for (const auto& td : type_to_devices) { strings::StrAppend(&text, "\n", td.first, ": ", td.second); } + strings::StrAppend(&text, + "\n\nColocation members and user-requested devices:"); + for (const Node* node : colocation_nodes) { + strings::StrAppend(&text, "\n ", node->name(), " (", node->type_string(), + ") ", node->requested_device()); + } + strings::StrAppend(&text, "\n"); if (num_nodes_found <= 1) { text.clear(); diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc index f9d9633beea1c59dc79880b2120332f3ee7588bd..e205e34aa0f6afb1363d65bd23403d4b50f056eb 100644 --- a/tensorflow/core/common_runtime/process_function_library_runtime.cc +++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc @@ -246,7 +246,7 @@ Status ProcessFunctionLibraryRuntime::ReleaseHandle( string target_device; { mutex_lock l(mu_); - CHECK_EQ(1, function_data_.count(handle)); + CHECK_EQ(1, function_data_.count(handle)) << " handle: " << handle; target_device = function_data_[handle].target_device; } flr = GetFLR(target_device); diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index 45cdab98e0642a3fbfee3dfa415696b98251600a..2acaa31d32de40148bd88021eb0613f0fb8522ff 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -211,14 +211,14 @@ Status ShapeRefiner::AddNode(const Node* node) { // For each 'input' of this node, fetch the corresponding shape // from 'input's InferenceContext, and store into a vector // indexed by 'node's input. - std::vector input_nodes(node->num_inputs()); + std::vector input_nodes(node->num_inputs()); std::vector input_shapes(node->num_inputs()); std::vector>> input_handle_shapes_and_types(node->num_inputs()); for (const Edge* e : node->in_edges()) { if (e->IsControlEdge()) continue; - Node* input = e->src(); + const Node* input = e->src(); auto it = node_to_context_.find(input); if (it == node_to_context_.end()) { return errors::FailedPrecondition( diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD index a32badef6dfdb8b62662da880c99842b1cafd13c..40cb8353cdccb4307f09b537ff7016e3dca5a8da 100644 --- a/tensorflow/core/debug/BUILD +++ b/tensorflow/core/debug/BUILD @@ -196,7 +196,10 @@ tf_cc_test( srcs = ["debug_gateway_test.cc"], args = ["--heap_check=local"], linkstatic = tf_kernel_tests_linkstatic(), - tags = ["no_gpu"], + tags = [ + "no_cuda_on_cpu_tap", + "no_gpu", + ], deps = [ ":debug", ":debug_gateway_internal", diff --git a/tensorflow/core/distributed_runtime/executor_test.cc b/tensorflow/core/distributed_runtime/executor_test.cc index 5b115f9a4d4ea3e9b99228918e16fc354d5a99fe..e34224205bac48a2dba1bf8cb07f9c623cd38281 100644 --- a/tensorflow/core/distributed_runtime/executor_test.cc +++ b/tensorflow/core/distributed_runtime/executor_test.cc @@ -57,7 +57,7 @@ class ExecutorTest : public ::testing::Test { } // Resets executor_ with a new executor based on a graph 'gdef'. - void Create(const Graph* graph) { + void Create(std::unique_ptr graph) { const int version = graph->versions().producer(); LocalExecutorParams params; params.device = device_; @@ -69,7 +69,7 @@ class ExecutorTest : public ::testing::Test { DeleteNonCachedKernel(kernel); }; delete exec_; - TF_CHECK_OK(NewLocalExecutor(params, graph, &exec_)); + TF_CHECK_OK(NewLocalExecutor(params, std::move(graph), &exec_)); runner_ = [this](std::function fn) { thread_pool_->Schedule(fn); }; rendez_ = NewLocalRendezvous(); } @@ -144,12 +144,12 @@ Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation, TEST_F(ExecutorTest, SimpleAdd) { // c = a + b - Graph* g = new Graph(OpRegistry::Global()); - auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB); - auto tmp = test::graph::Add(g, in0, in1); - test::graph::Send(g, tmp, "c", BOB, 1, ALICE); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); + auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB); + auto tmp = test::graph::Add(g.get(), in0, in1); + test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE); + Create(std::move(g)); Rendezvous::Args args; TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false)); // in0 = 1.0 @@ -172,15 +172,15 @@ TEST_F(ExecutorTest, SelfAdd) { // // b <- v10 // All nodes are executed by one thread. - Graph* g = new Graph(OpRegistry::Global()); - auto v = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto v = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); const int N = 10; for (int i = 1; i <= N; ++i) { - v = test::graph::Add(g, v, v); + v = test::graph::Add(g.get(), v, v); } // out <- v10 - test::graph::Send(g, v, "b", BOB, 1, ALICE); - Create(g); + test::graph::Send(g.get(), v, "b", BOB, 1, ALICE); + Create(std::move(g)); Rendezvous::Args args; // a = 1.0 TF_ASSERT_OK( @@ -229,9 +229,9 @@ void BuildTree(int N, Graph* g) { } TEST_F(ExecutorTest, RandomTree) { - Graph* g = new Graph(OpRegistry::Global()); - BuildTree(4096, g); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + BuildTree(4096, g.get()); + Create(std::move(g)); Rendezvous::Args args; TF_ASSERT_OK( rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false)); @@ -262,9 +262,9 @@ void BuildConcurrentAddAssign(Graph* g) { #ifndef THREAD_SANITIZER TEST_F(ExecutorTest, ConcurrentAddAssign) { - Graph* g = new Graph(OpRegistry::Global()); - BuildConcurrentAddAssign(g); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + BuildConcurrentAddAssign(g.get()); + Create(std::move(g)); for (int iters = 0; iters < 16; ++iters) { Rendezvous* rendez = NewLocalRendezvous(); TF_ASSERT_OK(Run(rendez)); @@ -281,12 +281,12 @@ TEST_F(ExecutorTest, ConcurrentAddAssign) { #endif TEST_F(ExecutorTest, SimpleSwitchLive) { - Graph* g = new Graph(OpRegistry::Global()); - auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Constant(g, VB(false)); - auto tmp = test::graph::Switch(g, in0, in1); - test::graph::Send(g, tmp, "c", BOB, 1, ALICE); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); + auto in1 = test::graph::Constant(g.get(), VB(false)); + auto tmp = test::graph::Switch(g.get(), in0, in1); + test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE); + Create(std::move(g)); Rendezvous::Args args; TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false)); // in0 = 1.0 @@ -300,12 +300,12 @@ TEST_F(ExecutorTest, SimpleSwitchLive) { } TEST_F(ExecutorTest, SimpleSwitchDead) { - Graph* g = new Graph(OpRegistry::Global()); - auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Constant(g, VB(true)); - auto tmp = test::graph::Switch(g, in0, in1); - test::graph::Send(g, tmp, "c", BOB, 1, ALICE); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); + auto in1 = test::graph::Constant(g.get(), VB(true)); + auto tmp = test::graph::Switch(g.get(), in0, in1); + test::graph::Send(g.get(), tmp, "c", BOB, 1, ALICE); + Create(std::move(g)); Rendezvous::Args args; TF_ASSERT_OK(rendez_->Send(Key(ALICE, kIncarnation, BOB, "a"), args, V(1.0), false)); // in0 = 1.0 @@ -319,16 +319,16 @@ TEST_F(ExecutorTest, SimpleSwitchDead) { TEST_F(ExecutorTest, Abort) { // e = a + b + c + d - Graph* g = new Graph(OpRegistry::Global()); - auto in0 = test::graph::Recv(g, "a", "float", ALICE, 1, BOB); - auto in1 = test::graph::Recv(g, "b", "float", ALICE, 1, BOB); - auto in2 = test::graph::Recv(g, "c", "float", ALICE, 1, BOB); - auto in3 = test::graph::Recv(g, "d", "float", ALICE, 1, BOB); - auto add0 = test::graph::Add(g, in0, in1); - auto add1 = test::graph::Add(g, in2, in3); - auto add2 = test::graph::Add(g, add0, add1); - test::graph::Send(g, add2, "e", BOB, 1, ALICE); - Create(g); + std::unique_ptr g(new Graph(OpRegistry::Global())); + auto in0 = test::graph::Recv(g.get(), "a", "float", ALICE, 1, BOB); + auto in1 = test::graph::Recv(g.get(), "b", "float", ALICE, 1, BOB); + auto in2 = test::graph::Recv(g.get(), "c", "float", ALICE, 1, BOB); + auto in3 = test::graph::Recv(g.get(), "d", "float", ALICE, 1, BOB); + auto add0 = test::graph::Add(g.get(), in0, in1); + auto add1 = test::graph::Add(g.get(), in2, in3); + auto add2 = test::graph::Add(g.get(), add0, add1); + test::graph::Send(g.get(), add2, "e", BOB, 1, ALICE); + Create(std::move(g)); // Needs 4 inputs (recv). One of them is aborted. rendez_->Ref(); @@ -371,17 +371,17 @@ TEST_F(ExecutorTest, Abort) { } TEST_F(ExecutorTest, RecvInvalidDtype) { - Graph* g = new Graph(OpRegistry::Global()); + std::unique_ptr g(new Graph(OpRegistry::Global())); // An input vector of type float of size 1. - auto one = test::graph::Recv(g, "one", "float", ALICE, 1, BOB); + auto one = test::graph::Recv(g.get(), "one", "float", ALICE, 1, BOB); // A floating point variable vector of size 1. - auto var = test::graph::Var(g, DT_FLOAT, TensorShape({1})); + auto var = test::graph::Var(g.get(), DT_FLOAT, TensorShape({1})); // Initialize the variable with input. - auto init = test::graph::Assign(g, var, one); + auto init = test::graph::Assign(g.get(), var, one); // Output - auto* two = test::graph::Send(g, var, "two", BOB, 1, ALICE); + auto* two = test::graph::Send(g.get(), var, "two", BOB, 1, ALICE); g->AddControlEdge(init, two); // Ensures run after init. - Create(g); + Create(std::move(g)); Rendezvous* rendez = NewLocalRendezvous(); // Send a double instead of float. TF_ASSERT_OK(rendez->Send(Key(ALICE, 1, BOB, "one"), Rendezvous::Args(), @@ -396,11 +396,11 @@ TEST_F(ExecutorTest, RecvInvalidDtype) { } TEST_F(ExecutorTest, RecvInvalidRefDtype) { - Graph* g = new Graph(OpRegistry::Global()); + std::unique_ptr g(new Graph(OpRegistry::Global())); // A var that always produces as invalid dtype. - auto var = test::graph::InvalidRefType(g, DT_FLOAT, DT_DOUBLE); - test::graph::Send(g, var, "out", BOB, 1, ALICE); - Create(g); + auto var = test::graph::InvalidRefType(g.get(), DT_FLOAT, DT_DOUBLE); + test::graph::Send(g.get(), var, "out", BOB, 1, ALICE); + Create(std::move(g)); Rendezvous* rendez = NewLocalRendezvous(); EXPECT_TRUE(errors::IsInternal(Run(rendez))); Tensor output; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 0120f612ac8bee32999304b1a6f63fff3802606a..7878ebb5f06db0f64e9216250da2a79352274ab3 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -271,7 +271,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, skip_cost_models_ = false; } TF_RETURN_IF_ERROR( - NewLocalExecutor(params, subgraph.release(), &unit->root)); + NewLocalExecutor(params, std::move(subgraph), &unit->root)); } return Status::OK(); } diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc index c4ac92d809627e7134b5d4ae694f9978cd5390b4..a6f4be3eaf69f40199e64c43dff443e886aa5aa1 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc @@ -106,7 +106,8 @@ GrpcServer::~GrpcServer() { Status GrpcServer::Init( ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, - const WorkerCreationFunction& worker_func) { + const WorkerCreationFunction& worker_func, + const StatsPublisherFactory& stats_factory) { mutex_lock l(mu_); CHECK_EQ(state_, NEW); master_env_.env = env_; @@ -218,7 +219,7 @@ Status GrpcServer::Init( master_env_.ops = OpRegistry::Global(); master_env_.worker_cache = worker_cache; master_env_.master_session_factory = - [config]( + [config, stats_factory]( SessionOptions options, const MasterEnv* env, std::unique_ptr>> remote_devs, std::unique_ptr worker_cache, @@ -226,7 +227,7 @@ Status GrpcServer::Init( options.config.MergeFrom(config); return new MasterSession(options, env, std::move(remote_devs), std::move(worker_cache), std::move(device_set), - CreateNoOpStatsPublisher); + stats_factory); }; master_env_.worker_cache_factory = [this](const WorkerCacheFactoryOptions& options, @@ -241,6 +242,14 @@ Status GrpcServer::Init( return Status::OK(); } +Status GrpcServer::Init( + ServiceInitFunction service_func, + const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const WorkerCreationFunction& worker_func) { + return Init(std::move(service_func), rendezvous_mgr_func, worker_func, + CreateNoOpStatsPublisher); +} + Status GrpcServer::Init( ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func) { diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h index 8b12ac1461d6b1fa3098197aa7697031a5d3075b..7c2f06f618a85c901ce7a7902cb8b1bc4e57be40 100644 --- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h +++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h @@ -22,6 +22,7 @@ limitations under the License. #include "grpc++/security/credentials.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/common_runtime/stats_publisher_interface.h" #include "tensorflow/core/distributed_runtime/master_env.h" #include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h" @@ -68,6 +69,11 @@ class GrpcServer : public ServerInterface { const string target() const override; protected: + Status Init(ServiceInitFunction service_func, + const RendezvousMgrCreationFunction& rendezvous_mgr_func, + const WorkerCreationFunction& worker_func, + const StatsPublisherFactory& stats_factory); + Status Init(ServiceInitFunction service_func, const RendezvousMgrCreationFunction& rendezvous_mgr_func, const WorkerCreationFunction& worker_func); diff --git a/tensorflow/core/distributed_runtime/scheduler.cc b/tensorflow/core/distributed_runtime/scheduler.cc index 9dae5b3b926fab14c2b36955436d3956baa29fdd..84036361971b73f9fb7fe990833d5018f6321e27 100644 --- a/tensorflow/core/distributed_runtime/scheduler.cc +++ b/tensorflow/core/distributed_runtime/scheduler.cc @@ -80,7 +80,7 @@ Microseconds SlackAnalysis::ComputeAsap(std::vector* asap_times) { std::vector pending_count(graph_->num_node_ids()); InitializePending(graph_, &pending_count); - std::deque queue; + std::deque queue; Node* srcNode = graph_->source_node(); queue.push_back(srcNode); (*asap_times)[srcNode->id()] = 0; @@ -92,7 +92,7 @@ Microseconds SlackAnalysis::ComputeAsap(std::vector* asap_times) { for (const Edge* out_edge : curr->out_edges()) { // The time needed for 'out' to get its input from 'curr'. Microseconds copy_time(0); - Node* out = out_edge->dst(); + const Node* out = out_edge->dst(); if (!out_edge->IsControlEdge() && curr->assigned_device_name() != out->assigned_device_name()) { // Add an arbitrary 10microsecs for each copy. @@ -137,7 +137,7 @@ Microseconds SlackAnalysis::ComputeAlap(std::vector* alap_times) { } } - std::deque queue; + std::deque queue; Node* sinkNode = graph_->sink_node(); queue.push_back(sinkNode); (*alap_times)[sinkNode->id()] = 0; @@ -148,7 +148,7 @@ Microseconds SlackAnalysis::ComputeAlap(std::vector* alap_times) { for (const Edge* in_edge : curr->in_edges()) { // The time needed for 'curr' to get its input from 'src'. Microseconds copy_time(0); - Node* src = in_edge->src(); + const Node* src = in_edge->src(); if (!in_edge->IsControlEdge() && src->assigned_device_name() != curr->assigned_device_name()) { // TODO(yuanbyu): Use the real cost model @@ -236,7 +236,7 @@ Microseconds GreedyScheduler::ComputeSchedule( for (const Edge* out_edge : event.node->out_edges()) { Microseconds copy_time(0); - Node* out = out_edge->dst(); + const Node* out = out_edge->dst(); if (!out_edge->IsControlEdge() && event.node->assigned_device_name() != out->assigned_device_name()) { // TODO(yuanbyu): Use below with the real cost model. @@ -277,11 +277,11 @@ Microseconds GreedyScheduler::ComputeSchedule( return max_completion; } -Node* GreedyScheduler::GetNodeWithHighestPriority( - const std::vector& nodes) { - Node* curr_node = nullptr; +const Node* GreedyScheduler::GetNodeWithHighestPriority( + const std::vector& nodes) { + const Node* curr_node = nullptr; int64 curr_priority = kint64max; - for (Node* n : nodes) { + for (const Node* n : nodes) { if ((*priority_)[n->id()] < curr_priority) { curr_node = n; curr_priority = (*priority_)[n->id()]; diff --git a/tensorflow/core/distributed_runtime/scheduler.h b/tensorflow/core/distributed_runtime/scheduler.h index ef87b9834dba50cf628a8c29c70b0266661d6227..bf9d0d1bec33284a44f69412477edb4a0963e8a1 100644 --- a/tensorflow/core/distributed_runtime/scheduler.h +++ b/tensorflow/core/distributed_runtime/scheduler.h @@ -57,11 +57,11 @@ class GreedyScheduler { struct Sim { int degree_parallelism; int num_running; - std::vector ready_nodes; + std::vector ready_nodes; }; struct Event { - Node* node; + const Node* node; Microseconds time; bool is_completion; @@ -79,7 +79,7 @@ class GreedyScheduler { private: // Returns the ready node with the highest priority for a sim. - Node* GetNodeWithHighestPriority(const std::vector& nodes); + const Node* GetNodeWithHighestPriority(const std::vector& nodes); const DeviceSet* devices_; const CostModel* cost_model_; diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h index 4e9352ee32227376957157c7ada63390689ac39a..d977935b8a392adf1f78c38955f77f6f364502c9 100644 --- a/tensorflow/core/example/feature_util.h +++ b/tensorflow/core/example/feature_util.h @@ -56,9 +56,9 @@ limitations under the License. // // To add values to feature_lists: // AppendFeatureValues({4.0}, -// GetFeatureList("movie_ratings", &se)->Add()); +// GetFeatureList("images", &se)->Add()); // AppendFeatureValues({5.0, 3.0}, -// GetFeatureList("movie_ratings", &se)->Add()); +// GetFeatureList("images", &se)->Add()); // This will create a feature list keyed as "images" with two features: // feature_lists { // feature_list { diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 94bf34afa49f586e1bb61c1654865a5abc9abe19..a382b8be95f143898a8f52f887b9396f3823372b 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/visitable_allocator.h" #include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/framework/log_memory.h" @@ -68,15 +68,19 @@ void EnableCPUAllocatorFullStats(bool enable) { cpu_allocator_collect_full_stats = enable; } -class CPUAllocator : public Allocator { +class CPUAllocator : public VisitableAllocator { public: - CPUAllocator() {} + CPUAllocator() : allocation_begun_(false) {} ~CPUAllocator() override {} string Name() override { return "cpu"; } void* AllocateRaw(size_t alignment, size_t num_bytes) override { + if (!allocation_begun_) { + allocation_begun_ = true; + } + void* p = port::AlignedMalloc(num_bytes, alignment); if (cpu_allocator_collect_stats) { const std::size_t alloc_size = port::MallocExtension_GetAllocatedSize(p); @@ -88,16 +92,38 @@ class CPUAllocator : public Allocator { stats_.max_alloc_size = std::max(stats_.max_alloc_size, alloc_size); } + + // visit each Visitor in alloc_visitors_ + if (p != nullptr) { + for (const Visitor& v : alloc_visitors_) { + v(p, num_bytes); + } + } + return p; } void DeallocateRaw(void* ptr) override { + std::size_t alloc_size; + bool init_alloc_size = false; if (cpu_allocator_collect_stats) { - const std::size_t alloc_size = - port::MallocExtension_GetAllocatedSize(ptr); + alloc_size = port::MallocExtension_GetAllocatedSize(ptr); + init_alloc_size = true; mutex_lock l(mu_); stats_.bytes_in_use -= alloc_size; } + + // visit each Visitor in free_visitors_ + if (ptr != nullptr) { + if (!init_alloc_size) { + alloc_size = port::MallocExtension_GetAllocatedSize(ptr); + init_alloc_size = true; + } + for (const Visitor& v : free_visitors_) { + v(ptr, alloc_size); + } + } + port::AlignedFree(ptr); } @@ -117,10 +143,36 @@ class CPUAllocator : public Allocator { return port::MallocExtension_GetAllocatedSize(ptr); } + // REQUIRES: can only add visitors before the first Allocate call + + void AddAllocVisitor(Visitor visitor) override { + mutex_lock lock(visitor_mutex_); + CHECK(!allocation_begun_) + << "AddAllocVisitor may not be called after allocation has begun."; + alloc_visitors_.push_back(visitor); + } + + void AddFreeVisitor(Visitor visitor) override { + mutex_lock lock(visitor_mutex_); + CHECK(!allocation_begun_) + << "AddFreeVisitor may not be called after allocation has begun."; + free_visitors_.push_back(visitor); + } + private: mutex mu_; AllocatorStats stats_ GUARDED_BY(mu_); + // visitor_mutex_ protects write access to alloc_visitors_ and free_visitors_. + // While write access is mutually exclusive, reads may happen concurrently. + // This is okay because we may only append to alloc_visitors_ and + // free_visitors_ before first allocation, and subsequently we only read these + // vectors. + mutex visitor_mutex_; + std::vector alloc_visitors_; + std::vector free_visitors_; + std::atomic allocation_begun_; + TF_DISALLOW_COPY_AND_ASSIGN(CPUAllocator); }; diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 1838a8ad02d2bd5522ce3162fea53e3f5afc0309..fb6d5c69e135c0263845cf71b93ac53bb2a359ed 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -128,6 +128,8 @@ class DeviceBase { // using a single stream.) // "event_mgr" is used to delay deallocation of temporary GPU buffers. // TODO(pbar) Work out how to move this out of DeviceBase. + // GpuDeviceInfo name is an unfortunate legacy, it is used not only by GPUs + // but also by TPU devices (to provide default device context). struct GpuDeviceInfo { // Make sure all the defaults are NULL, so we can spot missing assignments. perftools::gputools::Stream* stream = nullptr; @@ -230,6 +232,7 @@ class DeviceBase { private: Env* const env_; CpuWorkerThreads* cpu_worker_threads_ = nullptr; + // Set by GPUs as well as by TPU devices. GpuDeviceInfo* gpu_device_info_ = nullptr; thread::ThreadPool* device_thread_pool_ = nullptr; Eigen::ThreadPoolDevice* eigen_cpu_device_ = nullptr; diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index eae8e6c3c10c4b49081aed0e253d9a6f382f562b..3e7b89d4ebc91df42ee81c1c9fe67c68e755f736 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -168,7 +168,7 @@ class FunctionInstantiationHelper { strings::StrAppend(&name, "_", i); } NodeDef* gnode = AddNode(name); - gnode->set_op("_Arg"); + gnode->set_op(FunctionLibraryDefinition::kArgOp); AddAttr("T", dtypes[i], gnode); AddAttr("index", arg_index, gnode); result_.arg_types.push_back(dtypes[i]); @@ -328,7 +328,7 @@ class FunctionInstantiationHelper { strings::StrAppend(&name, "_", i); } NodeDef* gnode = AddNode(name); - gnode->set_op("_Retval"); + gnode->set_op(FunctionLibraryDefinition::kRetOp); AddInput(nodes_.size() - 1, item->nid, item->idx + i); AddAttr("T", dtypes[i], gnode); AddAttr("index", (*ret_index)++, gnode); @@ -558,9 +558,9 @@ string Print(gtl::ArraySlice nodes) { std::vector ret; std::vector body; for (const NodeDef* n : nodes) { - if (n->op() == "_Arg") { + if (n->op() == FunctionLibraryDefinition::kArgOp) { arg.push_back(n); - } else if (n->op() == "_Retval") { + } else if (n->op() == FunctionLibraryDefinition::kRetOp) { ret.push_back(n); } else { body.push_back(n); diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index e27001133bbb5056abf1a3e1f5b9d69c8e01bc56..e00399f97de42ca6c683202fdec9142310fa6e2d 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -344,6 +344,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface { Status LookUp(const string& op_type_name, const OpRegistrationData** op_reg_data) const override; + // Ops created for function arguments bear the name given by `kArgOp`; those + // created for return values bear the name given by `kRetOp`. + static constexpr const char* const kArgOp = "_Arg"; + static constexpr const char* const kRetOp = "_Retval"; + static constexpr const char* const kGradientOp = "SymbolicGradient"; static constexpr const char* const kFuncAttr = "f"; @@ -404,6 +409,8 @@ struct FunctionBody; // Forward declare. Defined in common_runtime/device.h class Device; +// Forward declare. Defined in common_runtime/device_mgr.h +class DeviceMgr; class FunctionLibraryRuntime { public: @@ -518,6 +525,9 @@ class FunctionLibraryRuntime { // Returns the device on which the function executes. virtual Device* device() = 0; + // Get the DeviceMgr from which the device was obtained. + virtual const DeviceMgr* device_mgr() const = 0; + // Returns the function library definition that backs this runtime. // NOTE(mrry): The returned library definition is the default function library // for this runtime. The runtime may instantiate functions from separate diff --git a/tensorflow/core/framework/op.cc b/tensorflow/core/framework/op.cc index fadb60d744217daa0c569601c437146a70f9b4d5..fc5467b3c86934908c3f1261c79659c6a0469350 100644 --- a/tensorflow/core/framework/op.cc +++ b/tensorflow/core/framework/op.cc @@ -110,6 +110,15 @@ void OpRegistry::GetRegisteredOps(std::vector* op_defs) { } } +void OpRegistry::GetOpRegistrationData( + std::vector* op_data) { + mutex_lock lock(mu_); + MustCallDeferred(); + for (const auto& p : registry_) { + op_data->push_back(*p.second); + } +} + Status OpRegistry::SetWatcher(const Watcher& watcher) { mutex_lock lock(mu_); if (watcher_ && watcher) { diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index f7f1ed2a886548c39fa38239d65aa2a73564c3c4..3ccca4090d9804050c484d64a62826665b94d4d2 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -89,6 +89,9 @@ class OpRegistry : public OpRegistryInterface { // Get all registered ops. void GetRegisteredOps(std::vector* op_defs); + // Get all `OpRegistrationData`s. + void GetOpRegistrationData(std::vector* op_data); + // Watcher, a function object. // The watcher, if set by SetWatcher(), is called every time an op is // registered via the Register function. The watcher is passed the Status diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 0645ec42822fe7633e0517b28e50b0c221b3f80e..5d32b71628263fe89d6f54fd07b2fe18bbb55e53 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -1025,9 +1025,8 @@ StringPiece Tensor::tensor_data() const { } bool Tensor::SharesBufferWith(const Tensor& b) const { - CHECK_NE(nullptr, buf_); - CHECK_NE(nullptr, b.buf_); - return buf_->root_buffer() == b.buf_->root_buffer(); + return buf_ != nullptr && b.buf_ != nullptr && + buf_->root_buffer() == b.buf_->root_buffer(); } string Tensor::DebugString() const { diff --git a/tensorflow/core/framework/tensor_shape.h b/tensorflow/core/framework/tensor_shape.h index adb41b81c6ec019ce51a3871ca329c82f8a1f4b7..fe2ba375aa0c5c50009b3155338cd8860070d47a 100644 --- a/tensorflow/core/framework/tensor_shape.h +++ b/tensorflow/core/framework/tensor_shape.h @@ -191,9 +191,6 @@ class TensorShapeBase : public TensorShapeRep { /// Appends all the dimensions from `shape`. void AppendShape(const TensorShapeBase& shape); - // Maximum number of dimensions in a tensor. - static constexpr int MaxDimensions() { return 254; } - /// \brief Insert a dimension somewhere in the `TensorShape`. /// REQUIRES: `0 <= d <= dims()` /// REQUIRES: `size >= 0` diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 81644388abcf9c14bc5812069f25906a7f72b4cc..b613effd18bbbaf107a56b518859024db1c9bbb2 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -1085,6 +1085,21 @@ class DummyCPUAllocator : public Allocator { void DeallocateRaw(void* ptr) override {} }; +TEST(Tensor, SharesBufferWith) { + Tensor a_empty; + Tensor b_empty; + Tensor a(DT_FLOAT, TensorShape({1})); + Tensor b(DT_FLOAT, TensorShape({1})); + Tensor copy(a); + EXPECT_FALSE(a_empty.SharesBufferWith(a_empty)); + EXPECT_FALSE(a_empty.SharesBufferWith(b_empty)); + EXPECT_FALSE(a_empty.SharesBufferWith(a)); + EXPECT_FALSE(a_empty.SharesBufferWith(copy)); + EXPECT_TRUE(a.SharesBufferWith(a)); + EXPECT_FALSE(a.SharesBufferWith(b)); + EXPECT_TRUE(a.SharesBufferWith(copy)); +} + TEST(Tensor, FailureToAllocate) { TensorShape shape({1}); DummyCPUAllocator allocator; diff --git a/tensorflow/core/common_runtime/visitable_allocator.h b/tensorflow/core/framework/visitable_allocator.h similarity index 94% rename from tensorflow/core/common_runtime/visitable_allocator.h rename to tensorflow/core/framework/visitable_allocator.h index 8edf922d11ee1662b78771bfdc7c38e0144aee19..ed41b05531acaa1be803ac533854efe6160691b4 100644 --- a/tensorflow/core/common_runtime/visitable_allocator.h +++ b/tensorflow/core/framework/visitable_allocator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ -#define TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ +#ifndef TENSORFLOW_CORE_FRAMEWORK_VISITABLE_ALLOCATOR_H_ +#define TENSORFLOW_CORE_FRAMEWORK_VISITABLE_ALLOCATOR_H_ #include #include "tensorflow/core/framework/allocator.h" @@ -76,4 +76,4 @@ class TrackingVisitableAllocator : public TrackingAllocator, VisitableAllocator* allocator_; }; } // namespace tensorflow -#endif // TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_VISITABLE_ALLOCATOR_H_ diff --git a/tensorflow/core/graph/control_flow.cc b/tensorflow/core/graph/control_flow.cc index db6683d1e74512e37a40773b7642cf33eb888782..30ff19cd7eae794e0e9875ca0825b647b44b02af 100644 --- a/tensorflow/core/graph/control_flow.cc +++ b/tensorflow/core/graph/control_flow.cc @@ -24,23 +24,24 @@ limitations under the License. namespace tensorflow { -Status BuildControlFlowInfo(Graph* g, std::vector* info) { +Status BuildControlFlowInfo(const Graph* g, + std::vector* info) { info->clear(); info->resize(g->num_node_ids()); std::vector parent_nodes; parent_nodes.resize(g->num_node_ids()); - Node* src_node = g->source_node(); + const Node* src_node = g->source_node(); ControlFlowInfo& src_info = (*info)[src_node->id()]; src_info.frame = src_node; src_info.parent_frame = src_node; string frame_name; - std::deque ready; + std::deque ready; ready.push_back(src_node); while (!ready.empty()) { - Node* curr_node = ready.front(); + const Node* curr_node = ready.front(); ready.pop_front(); const ControlFlowInfo& curr_info = (*info)[curr_node->id()]; const Node* frame = curr_info.frame; @@ -56,7 +57,7 @@ Status BuildControlFlowInfo(Graph* g, std::vector* info) { } for (const Edge* out_edge : curr_node->out_edges()) { - Node* out = out_edge->dst(); + const Node* out = out_edge->dst(); int out_id = out->id(); ControlFlowInfo* out_info = &(*info)[out_id]; const Node* out_parent = out_info->parent_frame; diff --git a/tensorflow/core/graph/control_flow.h b/tensorflow/core/graph/control_flow.h index 372044f538f9428e1979ba80bbb18a9742fc014e..79e2be0d4b9db6dd70d339ee07faf25c85376386 100644 --- a/tensorflow/core/graph/control_flow.h +++ b/tensorflow/core/graph/control_flow.h @@ -30,14 +30,14 @@ struct ControlFlowInfo { string frame_name; // frame name of a node }; -// Assign to each node the name of the frame and the level it belongs to. -// We check the well-formedness of the graph: All inputs to a node must -// come from the same frame and have the same "static" iteration level. -// `info` is cleared and populated by this function. -// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level -// 0. This essentially means there can't be multiple serial Nexts in -// an iteration, which all sane front-ends should satisfy. -Status BuildControlFlowInfo(Graph* g, std::vector* info); +// Clear and populate `info` with each node's frame and the level it belongs to. +// We check the well-formedness of the graph: All inputs to a node must come +// from the same frame and have the same "static" iteration level. +// +// NOTE(yuanbyu): For now, we require all sends/recvs have iteration level 0. +// This essentially means there can't be multiple serial Nexts in an iteration, +// which all sane front-ends should satisfy. +Status BuildControlFlowInfo(const Graph* g, std::vector* info); } // namespace tensorflow diff --git a/tensorflow/core/graph/costmodel.cc b/tensorflow/core/graph/costmodel.cc index 4f3a6ec38cb88213c7127df41823bc16e9834d09..1df45d9b893fdb2807c5e6ab63dd4a8577d7feb6 100644 --- a/tensorflow/core/graph/costmodel.cc +++ b/tensorflow/core/graph/costmodel.cc @@ -427,7 +427,7 @@ static void AssignSizes(const Graph& g, CostModel* cost_model) { if (e->IsControlEdge()) { continue; } - Node* src = e->src(); + const Node* src = e->src(); // TODO(josh11b): Get an estimate from the Op Bytes size(1); diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 9b56216f1f97a9598dd7ae8b70786e32bb7e0f4b..a7af5e2312af716ef25cb35c8f247d6feccb6d9c 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -339,7 +339,7 @@ Node* Graph::AddNode(const NodeDef& node_def, Status* status) { return node; } -Node* Graph::CopyNode(Node* node) { +Node* Graph::CopyNode(const Node* node) { DCHECK(!node->IsSource()); DCHECK(!node->IsSink()); Node* copy = AllocateNode(node->props_, node); diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 9d96cd4654bbf1fd65c5135d6a8bdc271c6e9443..cbd58b051afde592731ddf2b2ed61854cdfac49e 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -422,7 +422,7 @@ class Graph { // Copies *node, which may belong to another graph, to a new node, // which is returned. Does not copy any edges. *this owns the // returned instance. - Node* CopyNode(Node* node); + Node* CopyNode(const Node* node); // Removes a node from this graph, including all edges from or to it. // *node should not be accessed after calling this function. diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index 0629ff32d00cf7fad00c39f07810aa4a9d57f14f..627309078ac51a25fe2924935c191ec1c4d2a534 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -1271,7 +1271,7 @@ void CopyGraph(const Graph& src, Graph* dest) { dest->set_versions(src.versions()); // Copy the nodes - std::unordered_map + std::unordered_map node_map; // "Node in src" -> "Node in *dest" node_map[src.source_node()] = dest->source_node(); node_map[src.sink_node()] = dest->sink_node(); diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc index add80eda23d7887fb06902c0b123c03db8f4cccf..17a174101b2be479bea834a407544b3a74dc08cf 100644 --- a/tensorflow/core/graph/graph_partition.cc +++ b/tensorflow/core/graph/graph_partition.cc @@ -123,8 +123,8 @@ bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) { return false; } - Node* src = edge->src(); - Node* dst = edge->dst(); + const Node* src = edge->src(); + const Node* dst = edge->dst(); if (src->assigned_device_name() == dst->assigned_device_name()) { int src_port = edge->src_output(); int dst_port = edge->dst_input(); @@ -141,7 +141,7 @@ bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) { // Return true iff (dst, dst_input) is specified on host memory. bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) { - Node* dst = edge->dst(); + const Node* dst = edge->dst(); int dst_port = edge->dst_input(); if (info.device_types[dst->id()] != DEVICE_CPU) { if (edge->IsControlEdge()) return false; diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index 5343e6802d1e75f516925d44ab680b96f4e157da..e9ced4d2b6b2e7bffa0fbe61f546bef0aa9db974 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -222,7 +222,7 @@ Status MklToTfConversionPass::InsertInputConversionNode( BaseType(n->input_type(0))); // Check ordering of edges - for (uint i = 0; i < 4; i++) { + for (uint32 i = 0; i < 4; i++) { CHECK_EQ((edges[i]->dst_input() == i), true); } diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index 138952dcb33e7b1e57cf013147581d20f509e85d..114962c0e4f2969fe539d5b168aaf62d577a7024 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -88,7 +88,7 @@ NodeBuilder& NodeBuilder::ControlInput(Node* src_node) { NodeBuilder& NodeBuilder::ControlInputs(gtl::ArraySlice src_nodes) { control_inputs_.insert(control_inputs_.end(), src_nodes.begin(), src_nodes.end()); - for (Node* src_node : src_nodes) { + for (const Node* src_node : src_nodes) { def_builder_.ControlInput(src_node->name()); } return *this; @@ -127,7 +127,7 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const { return Status::OK(); } -void NodeBuilder::AddIndexError(Node* node, int i) { +void NodeBuilder::AddIndexError(const Node* node, int i) { if (node == nullptr) { errors_.emplace_back( strings::StrCat("Attempt to add nullptr Node to node with type ", @@ -140,7 +140,7 @@ void NodeBuilder::AddIndexError(Node* node, int i) { } } -bool NodeBuilder::GetOutputType(Node* node, int i, DataType* dt) { +bool NodeBuilder::GetOutputType(const Node* node, int i, DataType* dt) { bool error; *dt = SafeGetOutput(node, i, &error); if (error) AddIndexError(node, i); diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h index 86647a49c12085b6850a0e6d2622ef1bb58c513d..f6b7b5674b032cd2b19d69765e7c3b6b6613b3bd 100644 --- a/tensorflow/core/graph/node_builder.h +++ b/tensorflow/core/graph/node_builder.h @@ -120,7 +120,7 @@ class NodeBuilder { const OpDef& op_def() const { return def_builder_.op_def(); } private: - static DataType SafeGetOutput(Node* node, int i, bool* error) { + static DataType SafeGetOutput(const Node* node, int i, bool* error) { if (node != nullptr && i >= 0 && i < node->num_outputs()) { *error = false; return node->output_type(i); @@ -131,11 +131,11 @@ class NodeBuilder { } // If SafeGetOutput indicates a range error, add it to errors_. - void AddIndexError(Node* node, int i); + void AddIndexError(const Node* node, int i); // Set *dt and returns true if i is in range. Combines // SafeGetOutput() and AddIndexError(). - bool GetOutputType(Node* node, int i, DataType* dt); + bool GetOutputType(const Node* node, int i, DataType* dt); NodeDefBuilder def_builder_; std::vector inputs_; diff --git a/tensorflow/core/graph/optimizer_cse.cc b/tensorflow/core/graph/optimizer_cse.cc index 6b452a1d5dca0a636264a3483e4ee9d027fd2e06..4073255db3f7cbcd697f3cb2781e04b3b01634c1 100644 --- a/tensorflow/core/graph/optimizer_cse.cc +++ b/tensorflow/core/graph/optimizer_cse.cc @@ -65,8 +65,8 @@ class OptimizerCSE { }; static void FillInputs(const Node* n, - gtl::InlinedVector* control_edges, - gtl::InlinedVector, 4>* in) { + gtl::InlinedVector* control_edges, + gtl::InlinedVector, 4>* in) { DCHECK_EQ(in->size(), n->num_inputs()); control_edges->clear(); for (const Edge* e : n->in_edges()) { @@ -96,8 +96,8 @@ size_t OptimizerCSE::NodeHash(const Node* n) { const int N_in = n->num_inputs(); strings::StrAppend(&str_to_hash, N_in); - gtl::InlinedVector control_edges; - gtl::InlinedVector, 4> in(N_in); + gtl::InlinedVector control_edges; + gtl::InlinedVector, 4> in(N_in); FillInputs(n, &control_edges, &in); for (const auto& edge : in) { strings::StrAppend(&str_to_hash, edge.first->id(), edge.second); @@ -147,10 +147,10 @@ bool OptimizerCSE::Equivalent(const Node* a, const Node* b, // Compare input sources if (a->num_inputs() != b->num_inputs()) return false; const int N_in = a->num_inputs(); - gtl::InlinedVector a_control_edges; - gtl::InlinedVector b_control_edges; - gtl::InlinedVector, 4> a_in(N_in); - gtl::InlinedVector, 4> b_in(N_in); + gtl::InlinedVector a_control_edges; + gtl::InlinedVector b_control_edges; + gtl::InlinedVector, 4> a_in(N_in); + gtl::InlinedVector, 4> b_in(N_in); FillInputs(a, &a_control_edges, &a_in); FillInputs(b, &b_control_edges, &b_in); if (a_in != b_in) return false; diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index 0d88d1ff723b94783693559926c51c6726a2341b..67b252cb6c576b84de7f823ace2a1c7750d0c35b 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/graph/testlib.h" #include +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" @@ -50,7 +51,8 @@ REGISTER_KERNEL_BUILDER( REGISTER_OP("HostConst") .Output("output: dtype") .Attr("value: tensor") - .Attr("dtype: type"); + .Attr("dtype: type") + .SetShapeFn(shape_inference::UnknownShape); namespace test { namespace graph { diff --git a/tensorflow/core/grappler/clusters/BUILD b/tensorflow/core/grappler/clusters/BUILD index 5b8ce373bcf87a10875e764ba5cdbec96d58c080..b653f902e857ce804f797a016ebde551bf3b6695 100644 --- a/tensorflow/core/grappler/clusters/BUILD +++ b/tensorflow/core/grappler/clusters/BUILD @@ -1,7 +1,12 @@ licenses(["notice"]) # Apache 2.0 +load("//tensorflow:tensorflow.bzl", "if_cuda") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.bzl", "tf_cuda_library") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) filegroup( name = "all_files", @@ -26,13 +31,12 @@ config_setting( tf_cuda_library( name = "utils", srcs = ["utils.cc"], - hdrs = [ - "utils.h", - ], + hdrs = ["utils.h"], visibility = ["//visibility:public"], deps = [ "//third_party/eigen3", "//tensorflow/core:framework", + "//tensorflow/core:gpu_id", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ] + select({ @@ -41,6 +45,21 @@ tf_cuda_library( }), ) +tf_cc_test( + name = "utils_test", + srcs = ["utils_test.cc"], + linkstatic = if_cuda(1, 0), + tags = tf_cuda_tests_tags(), + deps = [ + ":utils", + "//tensorflow/core:gpu_id", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "cluster", srcs = ["cluster.cc"], @@ -104,6 +123,7 @@ cc_library( "//tensorflow/core:core_cpu_lib", "//tensorflow/core:direct_session", "//tensorflow/core:framework", + "//tensorflow/core:gpu_id", "//tensorflow/core:lib", "//tensorflow/core/grappler:utils", "//tensorflow/core/kernels:ops_util", @@ -114,7 +134,10 @@ tf_cc_test( name = "single_machine_test", srcs = ["single_machine_test.cc"], args = ["--heap_check=local"], # The GPU tracer leaks memory - tags = ["no_gpu"], + tags = [ + "no_cuda_on_cpu_tap", + "no_gpu", + ], deps = [ ":single_machine", "//tensorflow/cc:cc_ops", diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index 862ce4ae8883f394fd299914e245a69f1962f564..8e236c9ee80f30f7aa5c00f32fd137a718215cf3 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ b/tensorflow/core/grappler/clusters/single_machine.cc @@ -21,6 +21,8 @@ limitations under the License. #include "tensorflow/cc/training/queue_runner.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/grappler/clusters/utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/kernels/ops_util.h" @@ -28,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session.h" namespace tensorflow { @@ -79,14 +82,27 @@ Status SingleMachine::Provision() { std::vector devices; TF_RETURN_IF_ERROR(session_->ListDevices(&devices)); - int gpu_id = 0; for (const auto& dev : devices) { DeviceProperties attr; if (dev.device_type() == "CPU") { attr = GetLocalCPUInfo(); } else if (dev.device_type() == "GPU") { - attr = GetLocalGPUInfo(gpu_id++); - } else { + DeviceNameUtils::ParsedName parsed; + if (!DeviceNameUtils::ParseFullName(dev.name(), &parsed)) { + return errors::InvalidArgument( + strings::StrCat("Not able to parse GPU device name: ", dev.name())); + } + TfGpuId tf_gpu_id(parsed.id); + CudaGpuId cuda_gpu_id; + Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + if (!s.ok()) { + return errors::Unavailable("Unknown TF GPU device with id ", + tf_gpu_id.value(), ": ", s.ToString()); + } + attr = GetLocalGPUInfo(cuda_gpu_id); + } else if (dev.device_type().find("XLA") == string::npos) { + // Filter out the fake XLA devices to avoid double counting the actual + // hardware resources that are available. attr.set_type(dev.device_type()); } // Overwrite the memory size since users might have requested to use only a diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc index aacd2ccb72df07ac6b31c9bd5b96deca499038e4..b54b34959a53b56022a449ca286ff0ba823f2aa5 100644 --- a/tensorflow/core/grappler/clusters/utils.cc +++ b/tensorflow/core/grappler/clusters/utils.cc @@ -27,6 +27,9 @@ limitations under the License. #include "include/libxsmm.h" #endif +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" @@ -51,7 +54,7 @@ DeviceProperties GetLocalCPUInfo() { int64 free_mem = port::AvailableRam(); if (free_mem < INT64_MAX) { - device.set_memory_size(free_mem); + device.set_memory_size(free_mem * 1024); } (*device.mutable_environment())["cpu_instruction_set"] = @@ -66,36 +69,40 @@ DeviceProperties GetLocalCPUInfo() { return device; } -DeviceProperties GetLocalGPUInfo(int gpu_id) { +DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id) { DeviceProperties device; device.set_type("GPU"); #if GOOGLE_CUDA cudaDeviceProp properties; - cudaError_t error = cudaGetDeviceProperties(&properties, gpu_id); - if (error == cudaSuccess) { - device.set_vendor("NVidia"); - device.set_model(properties.name); - device.set_frequency(properties.clockRate * 1e-3); - device.set_num_cores(properties.multiProcessorCount); - device.set_num_registers(properties.regsPerMultiprocessor); - // For compute capability less than 5, l1 cache size is configurable to - // either 16 KB or 48 KB. We use the initial configuration 16 KB here. For - // compute capability larger or equal to 5, l1 cache (unified with texture - // cache) size is 24 KB. This number may need to be updated for future - // compute capabilities. - device.set_l1_cache_size((properties.major < 5) ? 16 * 1024 : 24 * 1024); - device.set_l2_cache_size(properties.l2CacheSize); - device.set_l3_cache_size(0); - device.set_shared_memory_size_per_multiprocessor( - properties.sharedMemPerMultiprocessor); - device.set_memory_size(properties.totalGlobalMem); - // 8 is the number of bits per byte. 2 is accounted for - // double data rate (DDR). - device.set_bandwidth(properties.memoryBusWidth / 8 * - properties.memoryClockRate * 2); + cudaError_t error = cudaGetDeviceProperties(&properties, cuda_gpu_id.value()); + if (error != cudaSuccess) { + device.set_type("UNKNOWN"); + LOG(ERROR) << "Failed to get device properties, error code: " << error; + return device; } + device.set_vendor("NVIDIA"); + device.set_model(properties.name); + device.set_frequency(properties.clockRate * 1e-3); + device.set_num_cores(properties.multiProcessorCount); + device.set_num_registers(properties.regsPerMultiprocessor); + // For compute capability less than 5, l1 cache size is configurable to + // either 16 KB or 48 KB. We use the initial configuration 16 KB here. For + // compute capability larger or equal to 5, l1 cache (unified with texture + // cache) size is 24 KB. This number may need to be updated for future + // compute capabilities. + device.set_l1_cache_size((properties.major < 5) ? 16 * 1024 : 24 * 1024); + device.set_l2_cache_size(properties.l2CacheSize); + device.set_l3_cache_size(0); + device.set_shared_memory_size_per_multiprocessor( + properties.sharedMemPerMultiprocessor); + device.set_memory_size(properties.totalGlobalMem); + // 8 is the number of bits per byte. 2 is accounted for + // double data rate (DDR). + device.set_bandwidth(properties.memoryBusWidth / 8 * + properties.memoryClockRate * 2); + (*device.mutable_environment())["architecture"] = strings::StrCat(properties.major, ".", properties.minor); (*device.mutable_environment())["cuda"] = strings::StrCat(CUDA_VERSION); @@ -106,18 +113,26 @@ DeviceProperties GetLocalGPUInfo(int gpu_id) { } DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device) { + DeviceProperties unknown; + unknown.set_type("UNKNOWN"); + if (device.type == "CPU") { return GetLocalCPUInfo(); } else if (device.type == "GPU") { if (device.has_id) { - return GetLocalGPUInfo(device.id); + TfGpuId tf_gpu_id(device.id); + CudaGpuId cuda_gpu_id; + Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + if (!s.ok()) { + LOG(ERROR) << s; + return unknown; + } + return GetLocalGPUInfo(cuda_gpu_id); } else { - return GetLocalGPUInfo(0); + return GetLocalGPUInfo(CudaGpuId(0)); } } - DeviceProperties result; - result.set_type("UNKNOWN"); - return result; + return unknown; } } // end namespace grappler diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h index 191942040a1fdd276bb50f799ce314389c2cb0fe..df8e7dca44ad637aed8a6a2e87fc6e20bdf62606 100644 --- a/tensorflow/core/grappler/clusters/utils.h +++ b/tensorflow/core/grappler/clusters/utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_ #define TENSORFLOW_GRAPPLER_CLUSTERS_UTILS_H_ +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" #include "tensorflow/core/protobuf/device_properties.pb.h" #include "tensorflow/core/util/device_name_utils.h" @@ -27,7 +28,7 @@ DeviceProperties GetLocalCPUInfo(); // Returns the DeviceProperties for the specified GPU attached to the server on // which grappler is running. -DeviceProperties GetLocalGPUInfo(int gpu_id); +DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id); // Returns the DeviceProperties of the specified device DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device); diff --git a/tensorflow/core/grappler/clusters/utils_test.cc b/tensorflow/core/grappler/clusters/utils_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..74218adbac4eda3a7a780933b8116cfd2b7a1b18 --- /dev/null +++ b/tensorflow/core/grappler/clusters/utils_test.cc @@ -0,0 +1,100 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/clusters/utils.h" + +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/device_properties.pb.h" + +namespace tensorflow { +namespace grappler { +namespace { + +TEST(UtilsTest, GetLocalGPUInfo) { + GpuIdManager::TestOnlyReset(); +#if GOOGLE_CUDA + LOG(INFO) << "CUDA is enabled."; + DeviceProperties properties; + + // Invalid CUDA GPU ID. + properties = GetLocalGPUInfo(CudaGpuId(100)); + EXPECT_EQ("UNKNOWN", properties.type()); + + // Succeed when a valid CUDA GPU id was inserted. + properties = GetLocalGPUInfo(CudaGpuId(0)); + EXPECT_EQ("GPU", properties.type()); + EXPECT_EQ("NVIDIA", properties.vendor()); +#else + LOG(INFO) << "CUDA is not enabled."; + DeviceProperties properties; + + properties = GetLocalGPUInfo(CudaGpuId(0)); + EXPECT_EQ("GPU", properties.type()); + + properties = GetLocalGPUInfo(CudaGpuId(100)); + EXPECT_EQ("GPU", properties.type()); +#endif +} + +TEST(UtilsTest, GetDeviceInfo) { + GpuIdManager::TestOnlyReset(); + DeviceNameUtils::ParsedName device; + DeviceProperties properties; + + // Invalid type. + properties = GetDeviceInfo(device); + EXPECT_EQ("UNKNOWN", properties.type()); + + // Cpu info. + device.type = "CPU"; + properties = GetDeviceInfo(device); + EXPECT_EQ("CPU", properties.type()); + + // No TF GPU id provided. + device.type = "GPU"; + device.has_id = false; + properties = GetDeviceInfo(device); + EXPECT_EQ("GPU", properties.type()); +#if GOOGLE_CUDA + EXPECT_EQ("NVIDIA", properties.vendor()); +#endif + + // TF to CUDA GPU id mapping entry doesn't exist. + device.has_id = true; + device.id = 0; + properties = GetDeviceInfo(device); + EXPECT_EQ("UNKNOWN", properties.type()); + +#if GOOGLE_CUDA + // Invalid CUDA GPU id. + GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(0), CudaGpuId(100)); + properties = GetDeviceInfo(device); + EXPECT_EQ("UNKNOWN", properties.type()); + + // Valid CUDA GPU id. + GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(1), CudaGpuId(0)); + device.id = 1; + properties = GetDeviceInfo(device); + EXPECT_EQ("GPU", properties.type()); + EXPECT_EQ("NVIDIA", properties.vendor()); +#endif +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD index 0fe01e9c9e094ebfa7fd1e6200d775ef61775184..5336df1f51dbb5dd5f48857a088ece1b1a04dbb5 100644 --- a/tensorflow/core/grappler/costs/BUILD +++ b/tensorflow/core/grappler/costs/BUILD @@ -142,6 +142,7 @@ tf_cuda_library( "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:graph", + "//tensorflow/core:gpu_id", "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc index 76db1afd4a2831adcc8b9f7c54d4f3309d2a035c..29ef317e46f13bd64847fd898fcb2eb9fee67f1c 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc @@ -245,6 +245,8 @@ OpLevelCostEstimator::OpLevelCostEstimator() { {"Add", Eigen::internal::functor_traits< Eigen::internal::scalar_sum_op>::Cost}, {"ApproximateEqual", 1}, + {"BiasAdd", Eigen::internal::functor_traits< + Eigen::internal::scalar_sum_op>::Cost}, {"Div", Eigen::internal::functor_traits< Eigen::internal::scalar_quotient_op>::Cost}, {"Equal", 1}, @@ -718,24 +720,87 @@ int64 OpLevelCostEstimator::CountBatchMatMulOperations( return ops; } +bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto, + TensorShapeProto* tensor_shape_proto) { + tensor_shape_proto->Clear(); + // First convert TensorProto into Tensor class so that it correctly parses + // data values within TensorProto (whether it's in int_val, int64_val, + // tensor_content, or anything. + Tensor tensor(tensor_proto.dtype()); + if (!tensor.FromProto(tensor_proto)) { + LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- " + << "failed to parse TensorProto: " + << tensor_proto.DebugString(); + return false; + } + if (tensor.dims() != 1) { + LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- " + << "tensor is not 1D: " << tensor.dims(); + return false; + } + // Then, convert it back to TensorProto using AsProtoField, which makes sure + // the data is in int_val, int64_val, or such repeated data fields, not in + // tensor_content. + TensorProto temp_tensor; + tensor.AsProtoField(&temp_tensor); + +#define TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(type) \ + do { \ + for (const auto& value : temp_tensor.type##_val()) { \ + tensor_shape_proto->add_dim()->set_size(value); \ + } \ + } while (0) + + if (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT16 || + tensor.dtype() == DT_INT8 || tensor.dtype() == DT_UINT8) { + TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int); + } else if (tensor.dtype() == DT_INT64) { + TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(int64); + } else if (tensor.dtype() == DT_UINT32) { + TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint32); + } else if (tensor.dtype() == DT_UINT64) { + TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO(uint64); + } else { + LOG(WARNING) << "GetTensorShapeProtoFromTensorProto() -- " + << "Unsupported dtype: " << tensor.dtype(); + return false; + } +#undef TENSOR_VALUES_TO_TENSOR_SHAPE_PROTO + + return true; +} + // TODO(cliffy): Dedup this method and CountConv2DBackpropFilterOperations. int64 OpLevelCostEstimator::CountConv2DBackpropInputOperations( const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims, bool* found_unknown_shapes) const { int64 ops = 0; - if (op_features.op() != kConv2dBackpropInput) { - LOG(ERROR) << "Invalid Operation"; + DCHECK_EQ(kConv2dBackpropInput, op_features.op()); + + if (op_features.inputs_size() < 2) { + *found_unknown_shapes = true; return ops; } - if (op_features.outputs_size() != 1) { - // Need _output_shapes for input shape. - LOG(ERROR) << "No output shape in Conv2DBackpropInput op."; - return ops; + TensorShapeProto input_shape; + bool shape_found = false; + if (op_features.inputs(0).has_value()) { + const TensorProto& value = op_features.inputs(0).value(); + shape_found = GetTensorShapeProtoFromTensorProto(value, &input_shape); + } + if (!shape_found && op_features.outputs_size() == 1) { + input_shape = op_features.outputs(0).shape(); + shape_found = true; + } + if (!shape_found) { + // Set the minimum filter size that's feasible. + for (int i = 0; i < 4; ++i) { + input_shape.add_dim()->set_size(1); + } + *found_unknown_shapes = true; } - const auto& input_shape = op_features.outputs(0).shape(); ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( input_shape, op_features.inputs(1).shape(), op_features, found_unknown_shapes); @@ -758,18 +823,30 @@ int64 OpLevelCostEstimator::CountConv2DBackpropFilterOperations( const OpInfo& op_features, ConvolutionDimensions* returned_conv_dims, bool* found_unknown_shapes) const { int64 ops = 0; - if (op_features.op() != kConv2dBackpropFilter) { - LOG(ERROR) << "Invalid Operation"; - return ops; + DCHECK_EQ(kConv2dBackpropFilter, op_features.op()); + + TensorShapeProto filter_shape; + bool shape_found = false; + if (op_features.inputs_size() >= 2 && op_features.inputs(1).has_value()) { + const TensorProto& value = op_features.inputs(1).value(); + shape_found = GetTensorShapeProtoFromTensorProto(value, &filter_shape); + } + if (!shape_found && op_features.outputs_size() == 1) { + filter_shape = op_features.outputs(0).shape(); + shape_found = true; + } + if (!shape_found) { + // Set the minimum filter size that's feasible. + for (int i = 0; i < 4; ++i) { + filter_shape.add_dim()->set_size(1); + } + *found_unknown_shapes = true; } - if (op_features.outputs_size() != 1) { - // Need _output_shapes for input shape. - LOG(ERROR) << "No output shape in Conv2DBackpropFilter op."; + if (op_features.inputs_size() < 1) { + *found_unknown_shapes = true; return ops; } - - const auto& filter_shape = op_features.outputs(0).shape(); ConvolutionDimensions conv_dims = ConvolutionDimensionsFromInputs( op_features.inputs(0).shape(), filter_shape, op_features, found_unknown_shapes); diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h index a292e5e97fe52383648d74b08bb7a384b6278446..7bb530fe31a9f70d168ae16783fac7d487e5f12d 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h @@ -28,6 +28,9 @@ limitations under the License. namespace tensorflow { namespace grappler { +bool GetTensorShapeProtoFromTensorProto(const TensorProto& tensor_proto, + TensorShapeProto* tensor_shape_proto); + class OpLevelCostEstimator { public: OpLevelCostEstimator(); diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc index 60fc783472d2b6a1d50eb52e912da1fccbe8cf08..4790b9bab2c7d67e7a29d45aaf9f964c470c63df 100644 --- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc +++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/test.h" @@ -97,47 +99,81 @@ OpContext DescribeBatchMatMul(const std::vector& dims_a, // Wrangles the minimum number of proto fields to set up a 4D Tensor for cost // estimation purposes. void DescribeTensor4D(int dim0, int dim1, int dim2, int dim3, - OpInfo* op_features) { - auto input = op_features->add_inputs(); - auto shape = input->mutable_shape(); + OpInfo::TensorProperties* tensor) { + auto shape = tensor->mutable_shape(); shape->add_dim()->set_size(dim0); shape->add_dim()->set_size(dim1); shape->add_dim()->set_size(dim2); shape->add_dim()->set_size(dim3); - input->set_dtype(DT_FLOAT); + tensor->set_dtype(DT_FLOAT); } -// Returns an OpInfo for Conv2D with the minimum set of fields set up. +// DescribeConvolution constructs an OpContext for a Conv2D applied to an input +// tensor with shape (batch, ix, iy, iz1) and a kernel tensor with shape +// (kx, ky, iz2, oz). OpContext DescribeConvolution(int batch, int ix, int iy, int iz1, int iz2, int kx, int ky, int oz) { OpContext op_context; SetCpuDevice(&op_context.op_info); op_context.op_info.set_op("Conv2D"); - DescribeTensor4D(batch, ix, iy, iz1, &op_context.op_info); - DescribeTensor4D(kx, ky, iz2, oz, &op_context.op_info); + DescribeTensor4D(batch, ix, iy, iz1, op_context.op_info.add_inputs()); + DescribeTensor4D(kx, ky, iz2, oz, op_context.op_info.add_inputs()); + return op_context; } -OpContext DescribeOp(const string& op, int size1, int size2) { +// DescribeUnaryOp constructs an OpContext for the given operation applied to +// a 4-tensor with shape (size1, 1, 1, 1). +OpContext DescribeUnaryOp(const string& op, int size1) { OpContext op_context; SetCpuDevice(&op_context.op_info); op_context.op_info.set_op(op); - DescribeTensor4D(size1, 1, 1, 1, &op_context.op_info); - DescribeTensor4D(2 * size1, size2, 1, 1, &op_context.op_info); + DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs()); + DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_outputs()); + + return op_context; +} - auto output = op_context.op_info.add_outputs(); - auto shape = output->mutable_shape(); - shape->add_dim()->set_size(2 * size1); - shape->add_dim()->set_size(size2); - shape->add_dim()->set_size(1); - shape->add_dim()->set_size(1); - output->set_dtype(DT_FLOAT); +// DescribeBinaryOp constructs an OpContext for the given operation applied to +// a 4-tensor with dimensions (size1, 1, 1, 1) and a 4-tensor with dimensions +// (2 * size1, size2, 1, 1). +// +// The choice of dimension here is arbitrary, and is used strictly to test the +// cost model for applying elementwise operations to tensors with unequal +// dimension values. +OpContext DescribeBinaryOp(const string& op, int size1, int size2) { + OpContext op_context; + SetCpuDevice(&op_context.op_info); + op_context.op_info.set_op(op); + + DescribeTensor4D(size1, 1, 1, 1, op_context.op_info.add_inputs()); + DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_inputs()); + DescribeTensor4D(2 * size1, size2, 1, 1, op_context.op_info.add_outputs()); + + return op_context; +} +// DescribeBiasAdd constructs an OpContext for a BiasAdd applied to a 4-tensor +// with dimensions (1, 1, size2, size1) and a bias with dimension (size1), +// according to the constraint that the bias must be 1D with size equal to that +// of the last dimension of the input value. +OpContext DescribeBiasAdd(int size1, int size2) { + OpContext op_context; SetCpuDevice(&op_context.op_info); + op_context.op_info.set_op("BiasAdd"); + + DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_inputs()); + DescribeTensor4D(1, 1, size2, size1, op_context.op_info.add_outputs()); + + auto bias = op_context.op_info.add_inputs(); + bias->mutable_shape()->add_dim()->set_size(size1); + bias->set_dtype(DT_FLOAT); + return op_context; } + } // namespace class OpLevelCostEstimatorTest : public ::testing::Test { @@ -164,8 +200,24 @@ class OpLevelCostEstimatorTest : public ::testing::Test { OpLevelCostEstimator estimator_; }; +TEST_F(OpLevelCostEstimatorTest, BiasAddExecutionTime) { + auto cost = PredictCosts(DescribeBiasAdd(1000, 10)); + EXPECT_EQ(Costs::Duration(8400), cost.memory_time); + EXPECT_EQ(Costs::Duration(1000), cost.compute_time); + EXPECT_EQ(Costs::Duration(9400), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); +} + +TEST_F(OpLevelCostEstimatorTest, Conv2DExecutionTime) { + auto cost = PredictCosts(DescribeConvolution(16, 19, 19, 48, 48, 5, 5, 256)); + EXPECT_EQ(Costs::Duration(233780), cost.memory_time); + EXPECT_EQ(Costs::Duration(354877440), cost.compute_time); + EXPECT_EQ(Costs::Duration(355111220), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); +} + TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) { - auto cost = PredictCosts(DescribeOp("Dummy", 1000, 1)); + auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1)); EXPECT_EQ(Costs::Duration(2000), cost.memory_time); EXPECT_EQ(Costs::Duration(0), cost.compute_time); EXPECT_EQ(Costs::Duration(2000), cost.execution_time); @@ -174,7 +226,7 @@ TEST_F(OpLevelCostEstimatorTest, DummyExecutionTime) { TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) { SetComputeMemoryOverlap(true); - auto cost = PredictCosts(DescribeOp("Dummy", 1000, 1)); + auto cost = PredictCosts(DescribeBinaryOp("Dummy", 1000, 1)); EXPECT_EQ(Costs::Duration(2000), cost.memory_time); EXPECT_EQ(Costs::Duration(0), cost.compute_time); EXPECT_EQ(Costs::Duration(2000), cost.execution_time); // max(2000, 200) @@ -183,7 +235,7 @@ TEST_F(OpLevelCostEstimatorTest, ExecutionTimeSumOrMax) { } TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) { - auto cost = PredictCosts(DescribeOp("Mul", 1000, 1)); + auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 1)); EXPECT_EQ(Costs::Duration(2000), cost.memory_time); EXPECT_EQ(Costs::Duration(200), cost.compute_time); EXPECT_EQ(Costs::Duration(2200), cost.execution_time); @@ -191,7 +243,7 @@ TEST_F(OpLevelCostEstimatorTest, MulExecutionTime) { } TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) { - auto cost = PredictCosts(DescribeOp("Mul", 1000, 2)); + auto cost = PredictCosts(DescribeBinaryOp("Mul", 1000, 2)); EXPECT_EQ(Costs::Duration(3600), cost.memory_time); EXPECT_EQ(Costs::Duration(400), cost.compute_time); EXPECT_EQ(Costs::Duration(4000), cost.execution_time); @@ -199,13 +251,21 @@ TEST_F(OpLevelCostEstimatorTest, MulBroadcastExecutionTime) { } TEST_F(OpLevelCostEstimatorTest, ModExecutionTime) { - auto cost = PredictCosts(DescribeOp("Mod", 1000, 1)); + auto cost = PredictCosts(DescribeBinaryOp("Mod", 1000, 1)); EXPECT_EQ(Costs::Duration(2000), cost.memory_time); EXPECT_EQ(Costs::Duration(1600), cost.compute_time); EXPECT_EQ(Costs::Duration(3600), cost.execution_time); EXPECT_FALSE(cost.inaccurate); } +TEST_F(OpLevelCostEstimatorTest, ReluExecutionTime) { + auto cost = PredictCosts(DescribeUnaryOp("Relu", 1000)); + EXPECT_EQ(Costs::Duration(800), cost.memory_time); + EXPECT_EQ(Costs::Duration(100), cost.compute_time); + EXPECT_EQ(Costs::Duration(900), cost.execution_time); + EXPECT_FALSE(cost.inaccurate); +} + TEST_F(OpLevelCostEstimatorTest, UnknownOrPartialShape) { EXPECT_FALSE(PredictCosts(DescribeMatMul(2, 4, 7, 7)).inaccurate); EXPECT_TRUE(PredictCosts(DescribeMatMul(-1, 4, 7, 7)).inaccurate); @@ -247,5 +307,108 @@ TEST_F(OpLevelCostEstimatorTest, BatchMatMul) { EXPECT_NE(matmul_inaccurate, batch_matmul_inaccurate); } +// Helper functions for testing GetTensorShapeProtoFromTensorProto(). +void GetTensorProto(const DataType dtype, const std::vector& shape, + const std::vector values, const bool tensor_content, + TensorProto* tensor_proto) { + tensor_proto->Clear(); + TensorProto temp_tensor_proto; + temp_tensor_proto.set_dtype(dtype); + for (const auto& x : shape) { + temp_tensor_proto.mutable_tensor_shape()->add_dim()->set_size(x); + } + for (const auto& x : values) { + if (dtype == DT_INT64) { + temp_tensor_proto.add_int64_val(x); + } else if (dtype == DT_INT32 || dtype == DT_INT16 || dtype == DT_INT8 || + dtype == DT_UINT8) { + temp_tensor_proto.add_int_val(x); + } else if (dtype == DT_UINT32) { + temp_tensor_proto.add_uint32_val(x); + } else if (dtype == DT_UINT64) { + temp_tensor_proto.add_uint64_val(x); + } else { + CHECK(false) << "Unsupported dtype: " << dtype; + } + } + Tensor tensor(dtype); + CHECK(tensor.FromProto(temp_tensor_proto)); + if (tensor_content) { + tensor.AsProtoTensorContent(tensor_proto); + } else { + tensor.AsProtoField(tensor_proto); + } +} + +void ExpectTensorShape(const std::vector& expected, + const TensorShapeProto& tensor_shape_proto) { + TensorShape tensor_shape_expected(expected); + TensorShape tensor_shape(tensor_shape_proto); + + LOG(INFO) << "Expected: " << tensor_shape_expected.DebugString(); + LOG(INFO) << "TensorShape: " << tensor_shape.DebugString(); + EXPECT_TRUE(tensor_shape_expected == tensor_shape); +} + +TEST_F(OpLevelCostEstimatorTest, GetTensorShapeProtoFromTensorProto) { + TensorProto tensor_proto; + TensorShapeProto tensor_shape_proto; + + // Dimention larger than max value; should fail while converting to Tensor + // class. + tensor_proto.mutable_tensor_shape()->add_dim()->set_size(255); + EXPECT_FALSE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + + tensor_proto.Clear(); + // Expect only 1D shape. + tensor_proto.mutable_tensor_shape()->add_dim()->set_size(1); + tensor_proto.mutable_tensor_shape()->add_dim()->set_size(2); + EXPECT_FALSE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + + // Expect only handle integer data types. + GetTensorProto(DT_FLOAT, {}, {}, /*tensor_content=*/false, &tensor_proto); + EXPECT_FALSE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + + // Check GetTensorShapeProtoFromTensorProto() resturns correct values. + { + std::vector shape_expected = {10, 20, 30, 40}; + GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/false, + &tensor_proto); + EXPECT_TRUE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + ExpectTensorShape(shape_expected, tensor_shape_proto); + } + + { + std::vector shape_expected = {40, 20, 90, 40}; + GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/false, + &tensor_proto); + EXPECT_TRUE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + ExpectTensorShape(shape_expected, tensor_shape_proto); + } + + { + std::vector shape_expected = {10, 20, 30, 40}; + GetTensorProto(DT_INT32, {4}, shape_expected, /*tensor_content=*/true, + &tensor_proto); + EXPECT_TRUE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + ExpectTensorShape(shape_expected, tensor_shape_proto); + } + + { + std::vector shape_expected = {40, 20, 90, 40}; + GetTensorProto(DT_INT64, {4}, shape_expected, /*tensor_content=*/true, + &tensor_proto); + EXPECT_TRUE( + GetTensorShapeProtoFromTensorProto(tensor_proto, &tensor_shape_proto)); + ExpectTensorShape(shape_expected, tensor_shape_proto); + } +} + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 602f69f12ea9d24ebd94da73a2a76d1992f3bfb1..076945d5c626b9609448e339fcbd96de3e9d137f 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -26,6 +26,8 @@ limitations under the License. #include "cuda/include/cudnn.h" #endif +#include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_manager.h" #include "tensorflow/core/framework/allocation_description.pb.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" @@ -200,17 +202,25 @@ std::vector FindInputFeatures( } DeviceProperties GetDeviceInfo(const string& device_str) { + DeviceProperties unknown; + unknown.set_type("UNKNOWN"); + DeviceNameUtils::ParsedName parsed; if (DeviceNameUtils::ParseFullName(device_str, &parsed)) { if (parsed.type == "GPU") { - return GetLocalGPUInfo(parsed.id); + TfGpuId tf_gpu_id(parsed.id); + CudaGpuId cuda_gpu_id; + Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + if (!s.ok()) { + LOG(ERROR) << s; + return unknown; + } + return GetLocalGPUInfo(cuda_gpu_id); } else if (parsed.type == "CPU") { return GetLocalCPUInfo(); } } - DeviceProperties device; - device.set_type("UNKNOWN"); - return device; + return unknown; } DeviceProperties GetDeviceInfo(const CostGraphDef::Node& node) { diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 020492a3e9e23a8360a5e8804bc51ba6c5de67d1..3ac3ae0f8f835226bbc3ec5d6cec6cb890a6998f 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/grappler/costs/utils.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" @@ -324,7 +325,7 @@ Status VirtualScheduler::Init() { // Get the nodes that would run to output fetch_nodes. bool ill_formed = false; - std::vector nodes = + const std::vector fetch_fanin_nodes = ComputeTransitiveFanin(graph, fetch_nodes, &ill_formed); if (ill_formed) { return errors::InvalidArgument( @@ -338,7 +339,7 @@ Status VirtualScheduler::Init() { // exactly the same as those executed for real. One possible discrepancy could // be the control flow nodes, where tf only executes one path. std::unordered_map name_to_node; - for (const auto& node : nodes) { + for (const auto& node : fetch_fanin_nodes) { name_to_node[node->name()] = node; } @@ -359,14 +360,22 @@ Status VirtualScheduler::Init() { // Build node_map; for each node, create its NodeState and connect its inputs // and outputs. - for (const auto* curr_node : nodes) { + for (const auto* curr_node : fetch_fanin_nodes) { auto& curr_node_state = GetNodeStateOrCreateIt(curr_node); const string curr_node_device = DeviceName(curr_node); std::vector inputs; if (IsRecv(*curr_node)) { const auto& attr = curr_node->attr(); - const NodeDef* send = name_to_send[attr.at("tensor_name").s()]; - inputs = {send->name()}; + if (attr.count("tensor_name")) { + const auto& send_node_name = attr.at("tensor_name").s(); + auto it = name_to_send.find(send_node_name); + // If there is a _Send associated with the curr_node (_Recv), add it as + // input. + if (it != name_to_send.end()) { + const NodeDef* send = it->second; + inputs = {send->name()}; + } + } } else { for (const string& input : curr_node->input()) { inputs.push_back(input); @@ -425,9 +434,11 @@ Status VirtualScheduler::Init() { feed_nodes.find(curr_node->name()) != feed_nodes.end(); // Default case: node without inputs are ready at time 0. - const bool has_no_inputs = curr_node->input().empty(); + // Note that we check inputs vector which may be different to + // curr_node->input(); e.g., we add Send as input to Recv. + const bool has_no_inputs = inputs.empty(); - if (!IsRecv(*curr_node) && (given_as_feed || has_no_inputs)) { + if (given_as_feed || has_no_inputs) { curr_node_state.time_ready = Costs::Duration(); ready_nodes_->AddNode(curr_node); VLOG(3) << "Added ready node: " << curr_node->name(); @@ -446,13 +457,16 @@ Status VirtualScheduler::Init() { } if (ready_nodes_->Empty()) { - return Status(error::UNAVAILABLE, "No ready nodes in the graph."); + return errors::InvalidArgument("No ready nodes in the graph."); } - if (!feed_nodes.empty()) - LOG(ERROR) << "Some feed nodes were not found in the graph: " - << str_util::Join(feed_nodes, ","); - + if (!feed_nodes.empty()) { + // This isn't always a bug: when the caller hasn't specified the exact list + // of feed and fetch nodes, by default we consider all placeholders as feed + // nodes, but some of them may not be needed for the default fetch node. + VLOG(1) << "Some feed nodes were not consumed by the fetch fanin: " + << str_util::Join(feed_nodes, ","); + } initialized_ = true; return Status::OK(); } diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 53dcb497a6453dfa70c1215352e74e96796ebeb7..f9154e42f984c8dd8e774b83750b41a48087d7bb 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -205,6 +205,25 @@ class VirtualSchedulerTest : public ::testing::Test { dependency_["out"] = {"x", "y", "z", "w"}; } + // Graph with some placeholder feed nodes that are not in the fetch fan-in. + void CreateGrapplerItemWithUnnecessaryPlaceholderNodes() { + Scope s = Scope::NewRootScope().WithDevice(kCPU0); + auto unnecessary = ops::Placeholder(s.WithOpName("unnecessary"), DT_FLOAT); + auto x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT); + + GraphDef def; + TF_CHECK_OK(s.ToGraphDef(&def)); + + grappler_item_.reset(new GrapplerItem); + grappler_item_->id = "test_extra_placeholders"; + grappler_item_->graph = def; + grappler_item_->fetch = {"x"}; + + // Grappler Item Builder puts all placeholder nodes into the feed + // list by default. + grappler_item_->feed = {{"x", Tensor()}, {"unnecessary", Tensor()}}; + } + // NoOp that takes 7 NoOps as control dependency. void CreateGrapplerItemWithControlDependency() { Scope s = Scope::NewRootScope().WithDevice(kCPU0); @@ -394,6 +413,63 @@ versions { grappler_item_->fetch = {"Recv"}; } + void CreateGrapplerItemWithRecvWithoutSend() { + const string gdef_ascii = R"EOF( +node { + name: "Recv" + op: "_Recv" + device: "/job:localhost/replica:0/task:0/device:CPU:0" + attr { + key: "client_terminated" + value { + b: false + } + } + attr { + key: "recv_device" + value { + s: "/job:localhost/replica:0/task:0/device:CPU:0" + } + } + attr { + key: "send_device" + value { + s: "/job:localhost/replica:0/task:0/device:CPU:0" + } + } + attr { + key: "send_device_incarnation" + value { + i: 0 + } + } + attr { + key: "tensor_name" + value { + s: "test" + } + } + attr { + key: "tensor_type" + value { + type: DT_FLOAT + } + } +} +library { +} +versions { + producer: 24 +} + )EOF"; + + grappler_item_.reset(new GrapplerItem); + CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, + &grappler_item_->graph)); + grappler_item_->id = "test_graph"; + grappler_item_->fetch = {"Recv"}; + } + // A simple while loop void CreateGrapplerItemWithLoop() { // Test graph produced in python using: @@ -1700,6 +1776,16 @@ TEST_F(VirtualSchedulerTest, MemoryUsage) { cpu_state.mem_usage_snapshot_at_peak); } +TEST_F(VirtualSchedulerTest, UnnecessaryFeedNodes) { + CreateGrapplerItemWithUnnecessaryPlaceholderNodes(); + InitScheduler(); + + // Test that scheduler can run graphs with extra unnecessary feed nodes. + auto ops_executed = RunScheduler(""); + ASSERT_EQ(1, ops_executed.size()); + ASSERT_EQ(ops_executed.count("x"), 1); +} + TEST_F(VirtualSchedulerTest, ControlDependency) { // Init. CreateGrapplerItemWithControlDependency(); @@ -2015,5 +2101,17 @@ TEST_F(VirtualSchedulerTest, GraphWithSendRecvDifferentDevice) { 0); EXPECT_GT(ops_executed.count("Recv"), 0); } + +TEST_F(VirtualSchedulerTest, GraphWihtOnlyRecv) { + // Init. + CreateGrapplerItemWithRecvWithoutSend(); + InitScheduler(); + + // Run the scheduler. + auto ops_executed = RunScheduler(""); + + // Recv without Send will be treated as initially ready node. + EXPECT_GT(ops_executed.count("Recv"), 0); +} } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index 2f8549cf395f6b78154f7a6faf3fea06ea6c56c4..ad86356504e06d31ccc0a315fbd6991e49df0f19 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -32,6 +32,7 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef&& graphDef) { feed = other.feed; fetch = other.fetch; init_ops = other.init_ops; + keep_ops = other.keep_ops; expected_init_time = other.expected_init_time; save_op = other.save_op; restore_op = other.restore_op; @@ -82,6 +83,9 @@ std::unordered_set GrapplerItem::NodesToPreserve() const { for (const auto& node : init_ops) { result.insert(NodeName(node)); } + for (const auto& node : keep_ops) { + result.insert(NodeName(node)); + } if (!save_op.empty()) { result.insert(NodeName(save_op)); } diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h index 302685972a7f2908278a881112db9dbfb53c1c1a..06bba544c315476219ee83684df59a3da8720eea 100644 --- a/tensorflow/core/grappler/grappler_item.h +++ b/tensorflow/core/grappler/grappler_item.h @@ -58,6 +58,11 @@ struct GrapplerItem { // Queue runner(s) required to run the queue(s) of this model. std::vector queue_runners; + // List of op names to keep in the graph. This includes nodes that are + // referenced in various collections, and therefore must be preserved to + // ensure that the optimized metagraph can still be loaded. + std::vector keep_ops; + // Return the set of node evaluated during a regular train/inference step. std::vector MainOpsFanin() const; // Return the set of node run to populate the queues (if any). @@ -66,7 +71,8 @@ struct GrapplerItem { std::vector InitOpsFanin() const; // Return the set of variables accessed during a regular train/inference step. std::vector MainVariables() const; - // Return a set of node names that must be preserved. + // Return a set of node names that must be preserved. This includes feed and + // fetch nodes, keep_ops, init_ops. std::unordered_set NodesToPreserve() const; }; diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 7ba498dd06409635d7dfc282ab29f1133e299c9b..5ac52eefe1144e06f1e10f9c99dcef7591deb880 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -296,6 +296,14 @@ std::unique_ptr GrapplerItemFromMetaGraphDef( } } + // Add each node referenced in a collection to the list of nodes to keep. + for (const auto& col : meta_graph.collection_def()) { + const CollectionDef& collection = col.second; + for (const string& node : collection.node_list().value()) { + new_item->keep_ops.push_back(NodeName(node)); + } + } + for (auto& node : *new_item->graph.mutable_node()) { if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") { if (node.attr().count("dtype") == 0) { diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index fdf4540540b4b9f3d64ea767240ca4ea0c353d48..9b3755ddce61b0e5e44c8f5eacb18b69be63b043 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -256,6 +256,10 @@ bool IsRestore(const NodeDef& node) { node.op() == "RestoreSlice"); } +bool IsReverse(const NodeDef& node) { + return node.op() == "Reverse" || node.op() == "ReverseV2"; +} + bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; } bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; } @@ -272,6 +276,10 @@ bool IsShape(const NodeDef& node) { return node.op() == "Shape"; } bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; } +bool IsShuffle(const NodeDef& node) { + return node.op() == "Shuffle" || node.op() == "RandomShuffle"; +} + bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; } bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; } @@ -346,7 +354,8 @@ bool IsFreeOfSideEffect(const NodeDef& node) { return false; } const OpDef* op_def = nullptr; - Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); + const string& op_name = node.op(); + Status status = OpRegistry::Global()->LookUpOpDef(op_name, &op_def); if (!status.ok()) { return false; } @@ -360,7 +369,8 @@ bool IsFreeOfSideEffect(const NodeDef& node) { } } // Some nodes do in-place updates on regular tensor inputs. - if (GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace")) { + if (GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace") || + StringPiece(op_name).starts_with("Inplace")) { return false; } return true; diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 9cda40c0a6515caa9754d0c2f4f50a32f9fe8d98..1fa43a9b66b93c4dae1c30943b8466043af327ec 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -100,6 +100,7 @@ bool IsRecv(const NodeDef& node); bool IsReduction(const NodeDef& node); bool IsReshape(const NodeDef& node); bool IsRestore(const NodeDef& node); +bool IsReverse(const NodeDef& node); bool IsReverseV2(const NodeDef& node); bool IsRsqrtGrad(const NodeDef& node); bool IsSelect(const NodeDef& node); @@ -108,6 +109,7 @@ bool IsSend(const NodeDef& node); bool IsSlice(const NodeDef& node); bool IsShape(const NodeDef& node); bool IsShapeN(const NodeDef& node); +bool IsShuffle(const NodeDef& node); bool IsSigmoidGrad(const NodeDef& node); bool IsSoftplusGrad(const NodeDef& node); bool IsSoftsignGrad(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 3432de9dcd380f5e399a01604a8433c34a356b1e..a52d1c8df2981137727e6598590c764869c8b450 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -1,6 +1,9 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") filegroup( name = "all_files", @@ -157,6 +160,18 @@ cc_library( ], ) +cc_library( + name = "custom_graph_optimizer", + hdrs = [ + "custom_graph_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:lib", + ], +) + cc_library( name = "arithmetic_optimizer", srcs = ["arithmetic_optimizer.cc"], @@ -270,9 +285,36 @@ tf_cc_test( ], ) +tf_kernel_library( + name = "gpu_swapping_kernels", + srcs = [ + "gpu_swapping_kernels.cc", + ], + deps = [ + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "gpu_swapping_ops", + srcs = [ + "gpu_swapping_ops.cc", + ], + deps = [ + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + cc_library( name = "memory_optimizer", - srcs = ["memory_optimizer.cc"], + srcs = [ + "memory_optimizer.cc", + ], hdrs = [ "memory_optimizer.h", ], @@ -282,6 +324,7 @@ cc_library( ":graph_rewriter", ":static_schedule", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:graph_view", "//tensorflow/core/grappler:grappler_item", @@ -292,10 +335,13 @@ cc_library( "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/utils:traversal", - ], + ] + if_cuda([ + ":gpu_swapping_kernels", + ":gpu_swapping_ops", + ]), ) -tf_cc_test( +tf_cc_test_gpu( name = "memory_optimizer_test", srcs = ["memory_optimizer_test.cc"], deps = [ @@ -368,9 +414,12 @@ cc_library( ":arithmetic_optimizer", ":auto_parallel", ":constant_folding", + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", ":dependency_optimizer", ":graph_optimizer", ":layout_optimizer", + ":loop_optimizer", ":memory_optimizer", ":model_pruner", "//tensorflow/core:framework", @@ -380,3 +429,81 @@ cc_library( "//tensorflow/core/grappler/utils:topological_sort", ], ) + +tf_cc_test( + name = "meta_optimizer_test", + srcs = ["meta_optimizer_test.cc"], + deps = [ + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", + ":meta_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) + +cc_library( + name = "custom_graph_optimizer_registry", + srcs = ["custom_graph_optimizer_registry.cc"], + hdrs = ["custom_graph_optimizer_registry.h"], + visibility = ["//visibility:public"], + deps = [ + ":custom_graph_optimizer", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "custom_graph_optimizer_registry_test", + size = "small", + srcs = ["custom_graph_optimizer_registry_test.cc"], + deps = [ + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "loop_optimizer", + srcs = ["loop_optimizer.cc"], + hdrs = [ + "loop_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/costs:graph_properties", + ], +) + +tf_cc_test( + name = "loop_optimizer_test", + size = "small", + srcs = ["loop_optimizer_test.cc"], + deps = [ + ":loop_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + ], +) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 9c544c82bf7f77760e5a2090ca947fd7185e27b7..709a434e40e887502cac1317870eb0db8e0c2910 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -45,45 +45,6 @@ namespace tensorflow { namespace grappler { namespace { -template -bool SafeSetTensorValue(double value, Tensor* tensor) { - using RealType = typename Eigen::NumTraits::Real; - if (value > std::numeric_limits::max() || - value < std::numeric_limits::min()) { - return false; - } - tensor->flat()(0) = static_cast(value); - return true; -} - -#define HANDLE_CASE(DTYPE) \ - case DTYPE: \ - if (!SafeSetTensorValue::Type>( \ - static_cast(value), tensor)) { \ - return errors::InvalidArgument("Cannot store value ", value, \ - " in tensor of type " #DTYPE); \ - } \ - break - -Status SetTensorValue(DataType dtype, int value, Tensor* tensor) { - switch (dtype) { - // HANDLE_CASE(DT_HALF); - HANDLE_CASE(DT_FLOAT); - HANDLE_CASE(DT_DOUBLE); - HANDLE_CASE(DT_UINT8); - HANDLE_CASE(DT_INT8); - HANDLE_CASE(DT_UINT16); - HANDLE_CASE(DT_INT16); - HANDLE_CASE(DT_INT32); - HANDLE_CASE(DT_INT64); - HANDLE_CASE(DT_COMPLEX64); - HANDLE_CASE(DT_COMPLEX128); - default: - return errors::InvalidArgument("Unexpected type ", DataTypeString(dtype)); - } - return Status::OK(); -} - template bool AreInversePermutations(const std::vector& a, const std::vector& b) { if (a.size() != b.size()) { @@ -870,8 +831,13 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } TensorValue value(&t); NodeDef* new_const_node = AddNode(*node, "const", /*copy_node=*/false); - *new_const_node = - ConstantFolding::CreateNodeDef(new_const_node->name(), value); + status = ConstantFolding::CreateNodeDef(new_const_node->name(), value, + new_const_node); + if (!status.ok()) { + LOG(WARNING) << "Failed to create const node: " + << status.error_message(); + return ""; + } new_const_node->set_device(node->device()); nodes_to_simplify->PushBack(new_const_node); @@ -1077,7 +1043,12 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() { // consumers of `node` are already redirected to `simplified_tensor`. // Re-push the consumers into `nodes_to_simplify` for further // optimizations. - std::set consumers = node_map_->GetOutputs(node->name()); + const std::set outputs = node_map_->GetOutputs(node->name()); + std::vector consumers(outputs.begin(), outputs.end()); + std::sort(consumers.begin(), consumers.end(), + [](const NodeDef* n1, const NodeDef* n2) { + return n1->name() < n2->name(); + }); for (NodeDef* consumer : consumers) { // Update `consumer`'s use of `node` to `input`'s operand. for (int i = 0; i < consumer->input_size(); ++i) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 1e6f11c8aa06b1115c7b74b25120a9d7b7b4a76c..a5417aaa51b63380eb3228622097cb005f407f96 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -35,7 +35,9 @@ limitations under the License. #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/denormal.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/setround.h" #include "tensorflow/core/platform/tensor_coding.h" #include "tensorflow/core/public/version.h" #include "tensorflow/core/util/bcast.h" @@ -51,7 +53,14 @@ class EigenThreadPoolWrapper : public Eigen::ThreadPoolInterface { explicit EigenThreadPoolWrapper(thread::ThreadPool* pool) : pool_(pool) {} ~EigenThreadPoolWrapper() override {} void Schedule(std::function fn) override { - pool_->Schedule(std::move(fn)); + auto wrapped = [=]() { + // TensorFlow flushes denormals to zero and rounds to nearest, so we do + // the same here. + port::ScopedFlushDenormal flush; + port::ScopedSetRound round(FE_TONEAREST); + fn(); + }; + pool_->Schedule(std::move(wrapped)); } int NumThreads() const override { return pool_->NumThreads(); } int CurrentThreadId() const override { return pool_->CurrentThreadId(); } @@ -292,16 +301,16 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { // graph. const int node_count = graph_->node_size(); for (int i = 0; i < node_count; ++i) { - NodeDef& node = *graph_->mutable_node(i); - const string op = node.op(); + NodeDef* node = graph_->mutable_node(i); + const string op = node->op(); if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") { continue; } const std::vector& output = - properties.GetOutputProperties(node.name()); + properties.GetOutputProperties(node->name()); const std::vector& input = - properties.GetInputProperties(node.name()); + properties.GetInputProperties(node->name()); if (input.empty() || output.empty()) { continue; } @@ -328,35 +337,35 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { // could have multiple outputs). if (op == "Shape" || op == "Size" || op == "Rank") { // Replace the node with the corresponding constant. - node.set_op("Const"); - node.clear_attr(); - (*node.mutable_attr())["dtype"].set_type(type); + node->set_op("Const"); + node->clear_attr(); + (*node->mutable_attr())["dtype"].set_type(type); value.AsProtoTensorContent( - (*node.mutable_attr())["value"].mutable_tensor()); + (*node->mutable_attr())["value"].mutable_tensor()); // Turn the data input into a control dependency: this is needed to // ensure that the constant value will only be run in the // cases where the shape/rank/size would have been run in // the original graph. Additional inputs are extra control string ctrl_dep = - AddControlDependency(node.input(0), graph_, node_map_.get()); - node.set_input(0, ctrl_dep); - node_map_->AddOutput(NodeName(ctrl_dep), node.name()); + AddControlDependency(node->input(0), graph_, node_map_.get()); + node->set_input(0, ctrl_dep); + node_map_->AddOutput(NodeName(ctrl_dep), node->name()); } else { - auto outputs = node_map_->GetOutputs(node.name()); + auto outputs = node_map_->GetOutputs(node->name()); for (const auto& output : outputs) { for (int k = 0; k < output->input_size(); ++k) { int port; string node_name = ParseNodeName(output->input(k), &port); - if (node_name == node.name() && port == j) { + if (node_name == node->name() && port == j) { // Create a const node as ShapeN's output if not already. const string const_name = - OptimizedNodeName(node, strings::StrCat("-matshapes-", j)); + OptimizedNodeName(*node, strings::StrCat("-matshapes-", j)); if (node_map_->GetNode(const_name) == nullptr) { NodeDef* added_node = graph_->add_node(); added_node->set_name(const_name); added_node->set_op("Const"); - added_node->set_device(node.device()); + added_node->set_device(node->device()); node_map_->AddNode(added_node->name(), added_node); (*added_node->mutable_attr())["dtype"].set_type(type); value.AsProtoTensorContent( @@ -364,7 +373,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { // We add a control dependency to the original ShapeN node, // so that the node will only be run if all inputs of the // original ShapeN node are run. - string ctrl_dep = AddControlDependency(node.name(), graph_, + string ctrl_dep = AddControlDependency(node->name(), graph_, node_map_.get()); *added_node->add_input() = ctrl_dep; node_map_->AddOutput(NodeName(ctrl_dep), added_node->name()); @@ -529,7 +538,8 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs( out[j] = node_map_->GetNode(const_name); if (out[j] == nullptr) { out[j] = graph_->add_node(); - *out[j] = CreateNodeDef(const_name, TensorValue(&value)); + TF_RETURN_IF_ERROR( + CreateNodeDef(const_name, TensorValue(&value), out[j])); out[j]->set_device(node.device()); node_map_->AddNode(const_name, out[j]); string ctrl_dep = @@ -637,7 +647,8 @@ Status ConstantFolding::MaterializeReductionIndices( value.vec()(i) = i; } } - *reduction_indices = CreateNodeDef(const_name, TensorValue(&value)); + TF_RETURN_IF_ERROR( + CreateNodeDef(const_name, TensorValue(&value), reduction_indices)); reduction_indices->set_device(node->device()); string ctrl_dep = AddControlDependency(node->input(1), graph_, node_map_.get()); @@ -677,7 +688,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) { return false; } - // Skip control flow nodes, they can't be folded + // Skip control flow nodes, they can't be folded. if (ModifiesFrameInfo(node)) { return false; } @@ -686,12 +697,16 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } - // Skips ops that don't benefit from folding. - const string& op = node.op(); + // Don't fold stateful ops such as TruncatedNormal. + if (!IsFreeOfSideEffect(node)) { + return false; + } - if (op.find("Placeholder") == 0) { + // Skips ops that don't benefit from folding. + if (IsPlaceholder(node)) { return false; } + const string& op = node.op(); if (op.find("Save") != string::npos || op.find("Restore") != string::npos || op.find("Reader") != string::npos) { return false; @@ -703,16 +718,12 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } - // Don't fold stateful ops such as TruncatedNormal. const OpDef* op_def = nullptr; Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); if (!status.ok()) { return false; } - if (op_def->is_stateful()) { - return false; - } - + // Don't fold ops without outputs. if (op_def->output_arg_size() == 0) { return false; } @@ -777,8 +788,11 @@ Status CreateConstantTensorAttrValue(DataType type, double value, SET_TENSOR_VAL_CASE(DT_FLOAT, float, float); SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double); SET_TENSOR_VAL_CASE(DT_INT64, int64, int64); + SET_TENSOR_VAL_CASE(DT_UINT64, int64, int64); SET_TENSOR_VAL_CASE(DT_INT32, int32, int); + SET_TENSOR_VAL_CASE(DT_UINT32, int32, int); SET_TENSOR_VAL_CASE(DT_INT16, int32, int); + SET_TENSOR_VAL_CASE(DT_UINT16, int32, int); SET_TENSOR_VAL_CASE(DT_INT8, int32, int); SET_TENSOR_VAL_CASE(DT_UINT8, int32, int); SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool); @@ -792,59 +806,74 @@ Status CreateConstantTensorAttrValue(DataType type, double value, } // namespace // static -NodeDef ConstantFolding::CreateNodeDef(const string& name, - const TensorValue& tensor) { - NodeDef node; - node.set_name(name); - node.set_op("Const"); +Status ConstantFolding::CreateNodeDef(const string& name, + const TensorValue& tensor, + NodeDef* node) { + node->set_name(name); + node->set_op("Const"); AttrValue attr_type; attr_type.set_type(tensor->dtype()); - node.mutable_attr()->insert({"dtype", attr_type}); + node->mutable_attr()->insert({"dtype", attr_type}); AttrValue attr_tensor; TensorProto* t = attr_tensor.mutable_tensor(); bool optimized = false; + size_t encoded_size; // Use the packed representation whenever possible to avoid generating large // graphdefs. Moreover, avoid repeating the last values if they're equal. if (tensor->NumElements() > 4) { -#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, NAME) \ - const TYPE* val_ptr = tensor->flat().data(); \ - TYPE last = *val_ptr; \ - int64 last_index = 0; \ - for (int64 i = 0; i < tensor->NumElements(); ++i) { \ - TYPE cur = *val_ptr++; \ - if (cur != last) { \ - last = cur; \ - last_index = i; \ - } \ - } \ - if (last_index < kint32max) { \ - optimized = true; \ - t->mutable_##NAME##_val()->Reserve(last_index + 1); \ - t->mutable_##NAME##_val()->AddNAlreadyReserved(last_index + 1); \ - val_ptr = tensor->flat().data(); \ - for (int64 i = 0; i <= last_index; ++i) { \ - t->set_##NAME##_val(i, *val_ptr++); \ - } \ - } - - if (tensor->dtype() == DT_FLOAT) { - POPULATE_TENSOR_PROTO(tensor, t, float, float) - } else if (tensor->dtype() == DT_DOUBLE) { - POPULATE_TENSOR_PROTO(tensor, t, double, double) - } else if (tensor->dtype() == DT_INT64) { - POPULATE_TENSOR_PROTO(tensor, t, int64, int64) - } else if (tensor->dtype() == DT_INT32) { - POPULATE_TENSOR_PROTO(tensor, t, int32, int) - } else if (tensor->dtype() == DT_INT16) { - POPULATE_TENSOR_PROTO(tensor, t, int16, int) - } else if (tensor->dtype() == DT_INT8) { - POPULATE_TENSOR_PROTO(tensor, t, int8, int) - } else if (tensor->dtype() == DT_UINT8) { - POPULATE_TENSOR_PROTO(tensor, t, uint8, int) - } else if (tensor->dtype() == DT_BOOL) { - POPULATE_TENSOR_PROTO(tensor, t, bool, bool) +#define POPULATE_TENSOR_PROTO(tensor, t, TYPE, NAME) \ + { \ + const TYPE* val_ptr = tensor->flat().data(); \ + TYPE last = *val_ptr; \ + int64 last_index = 0; \ + for (int64 i = 0; i < tensor->NumElements(); ++i) { \ + TYPE cur = *val_ptr++; \ + if (cur != last) { \ + last = cur; \ + last_index = i; \ + } \ + } \ + if (last_index < kint32max) { \ + optimized = true; \ + encoded_size = (last_index + 1) * sizeof(NAME); \ + t->mutable_##NAME##_val()->Reserve(last_index + 1); \ + t->mutable_##NAME##_val()->AddNAlreadyReserved(last_index + 1); \ + val_ptr = tensor->flat().data(); \ + for (int64 i = 0; i <= last_index; ++i) { \ + t->set_##NAME##_val(i, *val_ptr++); \ + } \ + } \ + } \ + break + + switch (tensor->dtype()) { + case DT_FLOAT: + POPULATE_TENSOR_PROTO(tensor, t, float, float); + case DT_DOUBLE: + POPULATE_TENSOR_PROTO(tensor, t, double, double); + case DT_INT64: + POPULATE_TENSOR_PROTO(tensor, t, int64, int64); + case DT_UINT64: + POPULATE_TENSOR_PROTO(tensor, t, uint64, int64); + case DT_INT32: + POPULATE_TENSOR_PROTO(tensor, t, int32, int); + case DT_UINT32: + POPULATE_TENSOR_PROTO(tensor, t, uint32, int); + case DT_INT16: + POPULATE_TENSOR_PROTO(tensor, t, int16, int); + case DT_UINT16: + POPULATE_TENSOR_PROTO(tensor, t, uint16, int); + case DT_INT8: + POPULATE_TENSOR_PROTO(tensor, t, int8, int); + case DT_UINT8: + POPULATE_TENSOR_PROTO(tensor, t, uint8, int); + case DT_BOOL: + POPULATE_TENSOR_PROTO(tensor, t, bool, bool); + default: + /* Do nothing. */ + break; } } if (optimized) { @@ -853,9 +882,15 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name, tensor->shape().AsProto(t->mutable_tensor_shape()); } else { tensor->AsProtoTensorContent(t); + encoded_size = t->tensor_content().size(); + } + node->mutable_attr()->insert({"value", attr_tensor}); + + if (encoded_size < 10 * 1024 * 1024) { + return Status::OK(); } - node.mutable_attr()->insert({"value", attr_tensor}); - return node; + return errors::InvalidArgument( + strings::StrCat("Can't fold ", name, ", its size would be too large")); } Status ConstantFolding::EvaluateNode(const NodeDef& node, @@ -929,17 +964,19 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node, return Status(error::INVALID_ARGUMENT, "Expected at least one output."); } + outputs->resize(output_tensors.size()); for (size_t i = 0; i < output_tensors.size(); i++) { string node_name = OptimizedNodeName(node, "-folded"); if (output_tensors.size() > 1) { node_name = strings::StrCat(node_name, "-", i); } if (output_tensors[i].tensor) { - outputs->push_back(CreateNodeDef(node_name, output_tensors[i])); + TF_RETURN_IF_ERROR( + CreateNodeDef(node_name, output_tensors[i], &outputs->at(i))); } else { // Create an empty NodeDef to identify dead outputs (e.g. the output of a // switch that's not selected by the switch predicate). - outputs->push_back(NodeDef()); + outputs->at(i) = NodeDef(); } } return Status::OK(); @@ -1147,9 +1184,8 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { std::unordered_set processed_nodes; std::deque queue; for (int i = 0; i < graph_->node_size(); i++) { - auto node = graph_->mutable_node(i); - if (IsFoldable(*node)) { - queue.push_back(node); + if (IsFoldable(graph_->node(i))) { + queue.push_back(graph_->mutable_node(i)); } } while (!queue.empty()) { @@ -1159,14 +1195,20 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { continue; } // We need to record a copy of output nodes before FoldNode() modifies it. - std::set outputs = node_map_->GetOutputs(node->name()); + // We also need to ensure that the fanout is sorted deterministically. + const std::set& outputs = node_map_->GetOutputs(node->name()); + std::vector fanout(outputs.begin(), outputs.end()); + std::sort(fanout.begin(), fanout.end(), + [](const NodeDef* n1, const NodeDef* n2) { + return n1->name() < n2->name(); + }); Status s = FoldNode(node, output); processed_nodes.insert(node->name()); if (!s.ok()) { VLOG(1) << "Failed to fold node " << node->name() << ": " << s; } else { - for (auto& output : outputs) { + for (auto& output : fanout) { if (IsFoldable(*output)) { queue.push_back(output); } @@ -1178,8 +1220,8 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { int last = output->node_size() - 1; for (int i = output->node_size() - 1; i >= 0; --i) { const NodeDef& node = output->node(i); - auto outputs = node_map_->GetOutputs(node.name()); - if (outputs.empty()) { + auto fanout = node_map_->GetOutputs(node.name()); + if (fanout.empty()) { output->mutable_node()->SwapElements(i, last); last--; } @@ -1191,8 +1233,8 @@ Status ConstantFolding::FoldGraph(GraphDef* output) { // If no fetch nodes is provided, we conservatively // keep all nodes in the original graph in case users need to fetch // their values. - auto outputs = node_map_->GetOutputs(node.name()); - if (!outputs.empty() || !has_fetch_ || + auto fanout = node_map_->GetOutputs(node.name()); + if (!fanout.empty() || !has_fetch_ || nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { auto added_node = output->add_node(); *added_node = node; @@ -1306,14 +1348,14 @@ bool ConstantFolding::IsOnes(const NodeDef& node) const { // IS_ONES_CASE(DT_HALF); IS_ONES_CASE(DT_FLOAT); IS_ONES_CASE(DT_DOUBLE); + IS_ONES_CASE(DT_COMPLEX64); + IS_ONES_CASE(DT_COMPLEX128); IS_ONES_CASE(DT_UINT8); IS_ONES_CASE(DT_INT8); IS_ONES_CASE(DT_UINT16); IS_ONES_CASE(DT_INT16); IS_ONES_CASE(DT_INT32); IS_ONES_CASE(DT_INT64); - IS_ONES_CASE(DT_COMPLEX64); - IS_ONES_CASE(DT_COMPLEX128); default: VLOG(1) << "Unsupported type " << DataTypeString(dtype); return false; @@ -1337,14 +1379,14 @@ bool ConstantFolding::IsZeros(const NodeDef& node) const { // IS_ZEROS_CASE(DT_HALF); IS_ZEROS_CASE(DT_FLOAT); IS_ZEROS_CASE(DT_DOUBLE); + IS_ZEROS_CASE(DT_COMPLEX64); + IS_ZEROS_CASE(DT_COMPLEX128); IS_ZEROS_CASE(DT_UINT8); IS_ZEROS_CASE(DT_INT8); IS_ZEROS_CASE(DT_UINT16); IS_ZEROS_CASE(DT_INT16); IS_ZEROS_CASE(DT_INT32); IS_ZEROS_CASE(DT_INT64); - IS_ZEROS_CASE(DT_COMPLEX64); - IS_ZEROS_CASE(DT_COMPLEX128); default: VLOG(1) << "Unsupported type " << DataTypeString(dtype); return false; @@ -1375,6 +1417,29 @@ void ConstantFolding::ReplaceOperationWithIdentity(int input_to_forward, graph_modified_ = true; } +void ConstantFolding::ReplaceOperationWithSnapshot(int input_to_forward, + NodeDef* node, + GraphDef* graph) { + node->set_op("Snapshot"); + DataType dtype = node->attr().at("T").type(); + node->clear_attr(); + (*node->mutable_attr())["T"].set_type(dtype); + + // Propagate the designated input through the Snapshot. + node->mutable_input()->SwapElements(0, input_to_forward); + // Add all other inputs as control dependencies. + for (int i = 1; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) { + break; + } + const string ctrl_dep = + AddControlDependency(node->input(i), graph, node_map_.get()); + node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep); + node->set_input(i, ctrl_dep); + } + graph_modified_ = true; +} + void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph) { node->set_op("Reciprocal"); @@ -1386,6 +1451,17 @@ void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node, graph_modified_ = true; } +void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node, + GraphDef* graph) { + node->set_op("Neg"); + node->mutable_input()->SwapElements(0, 1); + const string ctrl_dep = + AddControlDependency(node->input(1), graph, node_map_.get()); + node_map_->UpdateInput(node->name(), node->input(1), ctrl_dep); + node->set_input(1, ctrl_dep); + graph_modified_ = true; +} + Status ConstantFolding::ReplaceOperationWithConstant( double value, const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) { @@ -1417,6 +1493,122 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE; for (int i = 0; i < output->node_size(); ++i) { NodeDef* node = output->mutable_node(i); + // Remove Shuffle or Reverse op over scalar values. + if (use_shape_info && + (IsShuffle(*node) || IsReverse(*node) || IsTranspose(*node))) { + const auto& shape = + properties.GetInputProperties(node->name())[0].shape(); + // The node is replaceable iff + // unknown_rank == false && (dim_size == 0 || all dims have size 1) + bool replaceable = !shape.unknown_rank(); + for (int j = 0; j < shape.dim_size(); ++j) { + replaceable &= shape.dim(j).size() == 1; + } + if (replaceable) { + ReplaceOperationWithIdentity(0, node, output); + } + } + + // Switch(x, x) will always feed false to its false branch and true to + // its true branch. By rewriting the graph a bit, we can propagate these + // constants down the two output branches, and just use control dependencies + // to trigger the selected one at runtime. For example, + // + // +------+ + // x-->|Switch|-->a (in practice there may be multiple consumers of each + // x-->| |-->b output branch.) + // +------+ + // + // Is rewritten as + // + // +------+ + // x-->|Switch|-->Identity--^>Const(false)-->a + // x-->| |-->Identity--^>Const(true)-->b + // +------+ + if (node->op() == "Switch" && node->input(0) == node->input(1) && + !OptimizedNodeExists(*node, "_const_false") && + !OptimizedNodeExists(*node, "_const_true")) { + bool already_optimized = true; + // If the optimization was already applied, the switch would have exactly + // one Identity node consuming each of its outputs, each without any + // non-control outputs. + auto fanouts = node_map_->GetOutputs(node->name()); + if (fanouts.size() == 2) { + for (NodeDef* fanout : fanouts) { + if (!IsIdentity(*fanout) || + NumNonControlOutputs(*fanout, *node_map_) > 0) { + already_optimized = false; + break; + } + } + } + Tensor false_t(DT_BOOL, TensorShape({})); + Tensor true_t(DT_BOOL, TensorShape({})); + // Make sure we don't proceed if this switch node was already optimized. + if (!already_optimized && SetTensorValue(DT_BOOL, true, &true_t).ok() && + SetTensorValue(DT_BOOL, false, &false_t).ok()) { + // Copy the set of consumers of the switch as they will be manipulated + // below. + const std::set& consumer_set = + node_map_->GetOutputs(node->name()); + std::vector consumers(consumer_set.begin(), + consumer_set.end()); + std::sort(consumers.begin(), consumers.end(), + [](const NodeDef* n1, const NodeDef* n2) { + return n1->name() < n2->name(); + }); + // Create constant false & true nodes. + NodeDef* false_node = output->add_node(); + false_node->set_name(OptimizedNodeName(*node, "_const_false")); + if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), + false_node) + .ok()) { + continue; + } + false_node->set_device(node->device()); + + NodeDef* true_node = output->add_node(); + true_node->set_name(OptimizedNodeName(*node, "_const_true")); + if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node) + .ok()) { + continue; + } + true_node->set_device(node->device()); + + // Add controls from the switch ports to the constants, and connect the + // constants to the original switch outputs. + const string false_port = node->name(); + const string true_port = strings::StrCat(node->name(), ":1"); + const string false_ctrl_dep = + AddControlDependency(false_port, output, node_map_.get()); + false_node->add_input(false_ctrl_dep); + const string true_ctrl_dep = + AddControlDependency(true_port, output, node_map_.get()); + true_node->add_input(true_ctrl_dep); + + node_map_->AddNode(false_node->name(), false_node); + node_map_->AddNode(true_node->name(), true_node); + node_map_->AddOutput(NodeName(false_ctrl_dep), false_node->name()); + node_map_->AddOutput(NodeName(true_ctrl_dep), true_node->name()); + + for (NodeDef* consumer : consumers) { + for (int i = 0; i < consumer->input_size(); ++i) { + const string& input = consumer->input(i); + if (input == false_port) { + consumer->set_input(i, false_node->name()); + node_map_->UpdateInput(consumer->name(), false_port, + false_node->name()); + } else if (input == true_port) { + consumer->set_input(i, true_node->name()); + node_map_->UpdateInput(consumer->name(), true_port, + true_node->name()); + } + } + } + graph_modified_ = true; + continue; + } + } if (IsSimplifiableReduction(*node)) { // Replace the reduction node with an identity node, that can be further // optimized by the model pruner. @@ -1443,15 +1635,14 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, graph_modified_ = true; continue; } - const bool safe_to_use_shapes = - use_shape_info && (feed_nodes_.empty() || is_aggressive); + const bool is_mul = IsMul(*node); const bool is_matmul = IsMatMul(*node); const bool is_add = IsAdd(*node) || IsBiasAdd(*node); const bool is_sub = IsSub(*node); const bool is_any_div = IsAnyDiv(*node); // Simplify arithmetic operations with ones or zeros. - if (safe_to_use_shapes && + if (use_shape_info && (is_mul || is_matmul || is_add || is_sub || is_any_div) && properties.HasInputProperties(node->name()) && properties.HasOutputProperties(node->name())) { @@ -1473,9 +1664,14 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape); if (y_matches_output_shape && ((is_mul && x_is_one) || (is_add && x_is_zero))) { - // TODO(rmlarsen): Handle subtraction 0 - y. // 1 * y = y or 0 + y = y. - ReplaceOperationWithIdentity(1, node, output); + ReplaceOperationWithSnapshot(1, node, output); + continue; + } + + if (y_matches_output_shape && (is_sub && x_is_zero)) { + // Replace 0 - y with Neg(y). + ReplaceSubtractionFromZeroByNegation(node, output); continue; } @@ -1493,11 +1689,10 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, const bool y_is_zero = IsZeros(*y); const bool y_is_one = IsOnes(*y); const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape); - if (x_matches_output_shape && - (((is_mul || is_any_div) && y_is_one) || - ((is_add || is_sub) && y_is_zero && is_aggressive))) { + if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) || + ((is_add || is_sub) && y_is_zero))) { // x * 1 = x or x / 1 = x or x +/- 0 = x - ReplaceOperationWithIdentity(0, node, output); + ReplaceOperationWithSnapshot(0, node, output); continue; } @@ -1547,8 +1742,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, } // Insert new reciprocal op and change node from Div to Mul. NodeDef* reciprocal_node = output->add_node(); - reciprocal_node->set_name(AddPrefixToNodeName( - strings::StrCat(node->name(), "_recip"), kConstantFoldingConst)); + reciprocal_node->set_name(OptimizedNodeName(*node, "_recip")); reciprocal_node->set_op("Reciprocal"); reciprocal_node->set_device(node->device()); node->set_op("Mul"); @@ -1647,6 +1841,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output, graph_modified_ = true; } } + return Status::OK(); } @@ -1685,11 +1880,17 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, TF_RETURN_IF_ERROR(FoldGraph(output)); node_map_.reset(new NodeMap(output)); TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info)); + return Status::OK(); } Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* output) { + // TensorFlow flushes denormals to zero and rounds to nearest, so we do + // the same here. + port::ScopedFlushDenormal flush; + port::ScopedSetRound round(FE_TONEAREST); + nodes_to_preserve_ = item.NodesToPreserve(); for (const auto& feed : item.feed) { feed_nodes_.insert(NodeName(feed.first)); @@ -1724,5 +1925,5 @@ void ConstantFolding::Feedback(Cluster* cluster, const GrapplerItem& item, // Nothing to do for ConstantFolding. } -} // end namespace grappler -} // end namespace tensorflow +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 18acc91e8a18f4bf2eb77c7e5171eaca4ff5bec5..2fd59c7f9ccf3f94e683d7ec41a5848b9eec4a8f 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -33,7 +33,8 @@ const char kConstantFoldingCtrl[] = "ConstantFoldingCtrl"; // Constant folding optimization for a graph. class ConstantFolding : public GraphOptimizer { public: - static NodeDef CreateNodeDef(const string& name, const TensorValue& tensor); + static Status CreateNodeDef(const string& name, const TensorValue& tensor, + NodeDef* node); static string AddControlDependency(const string& input_name, GraphDef* graph, NodeMap* node_map); @@ -79,6 +80,9 @@ class ConstantFolding : public GraphOptimizer { bool IsZeros(const NodeDef& node) const; void ReplaceOperationWithIdentity(int input_to_forward, NodeDef* node, GraphDef* graph); + void ReplaceOperationWithSnapshot(int input_to_forward, NodeDef* node, + GraphDef* graph); + void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph); Status ReplaceOperationWithConstant(double value, const TensorShapeProto& shape, NodeDef* node, GraphDef* graph); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 46998dcc91c8df2313ff92b056f732379b173661..c6540192d7f85098f64ba42c0d4bf27dafc762ab 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -195,8 +195,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch = {"addn", "matmul3", "matmul4"}; - ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, - nullptr /* cpu_device */); + ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -214,11 +213,11 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("^zeros", node.input(0)); EXPECT_EQ("^y", node.input(1)); } else if (name == "mul3") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^ones", node.input(1)); } else if (name == "mul4") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("^ones", node.input(1)); } else if (name == "mul5") { @@ -230,7 +229,7 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("^zeros_1d", node.input(0)); EXPECT_EQ("^y", node.input(1)); } else if (name == "div1") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^ones", node.input(1)); } else if (name == "div2") { @@ -266,15 +265,15 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ(2, t.tensor_shape().dim(0).size()); EXPECT_EQ(3, t.tensor_shape().dim(1).size()); } else if (name == "add1") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^zeros", node.input(1)); } else if (name == "add2") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("^zeros", node.input(1)); } else if (name == "bias_add1") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^zeros_1d", node.input(1)); } else if (name == "bias_add2") { @@ -283,14 +282,13 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("zeros", node.input(0)); EXPECT_EQ("bias", node.input(1)); } else if (name == "sub1") { - EXPECT_EQ("Identity", node.op()); + EXPECT_EQ("Snapshot", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^zeros", node.input(1)); } else if (name == "sub2") { - // We don't handle this case yet. - EXPECT_EQ("Sub", node.op()); - EXPECT_EQ("zeros", node.input(0)); - EXPECT_EQ("y", node.input(1)); + EXPECT_EQ("Neg", node.op()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("^zeros", node.input(1)); } const std::set square_zero_const{"mul1", "mul2", "mul5", "mul6", "matmul1", "matmul2"}; @@ -322,8 +320,7 @@ TEST_F(ConstantFoldingTest, StrengthReduce_Reciprocal) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch = {"div_f", "div_i", "realdiv"}; - ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, - nullptr /* cpu_device */); + ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -413,8 +410,7 @@ TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, - nullptr /* cpu_device */); + ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -468,12 +464,10 @@ TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, - nullptr /* cpu_device */); + ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); - LOG(INFO) << output.DebugString(); EXPECT_EQ(10, output.node_size()); for (int i = 0; i < output.node_size(); ++i) { @@ -995,8 +989,10 @@ TEST_F(ConstantFoldingTest, SwitchNodesEmptyFetch) { EXPECT_EQ(present_nodes.size(), output.node_size()); int found = 0; for (const auto& node : output.node()) { - EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end()); - EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end()); + EXPECT_TRUE(present_nodes.find(node.name()) != present_nodes.end()) + << node.name(); + EXPECT_TRUE(not_present_nodes.find(node.name()) == not_present_nodes.end()) + << node.name(); present_nodes.erase(node.name()); not_present_nodes.erase(node.name()); if (node.name() == "rank") { @@ -1181,8 +1177,43 @@ TEST_F(ConstantFoldingTest, MergeNodes) { EXPECT_EQ(2, out_idx.flat()(0)); } +TEST_F(ConstantFoldingTest, ShuffleReverseOnScalarRemoval) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output in1 = + ops::Variable(scope.WithOpName("in1"), TensorShape({}), DT_FLOAT); + Output in2 = + ops::Variable(scope.WithOpName("in2"), TensorShape({}), DT_FLOAT); + ops::RandomShuffle s1(scope.WithOpName("s1"), in1); + ops::RandomShuffle s2(scope.WithOpName("s2").WithControlDependencies({in1}), + in2); + + ops::Add out1(scope.WithOpName("out1"), s1, s2); + ops::Identity out2(scope.WithOpName("out2"), s2); + + GrapplerItem item; + item.fetch = {"out1", "out2"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding fold(nullptr /* cpu_device */); + GraphDef got; + Status status = fold.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("in1", "VariableV2", {}, &want); + AddNode("in2", "VariableV2", {}, &want); + AddNode("s1", "Identity", {"in1"}, &want); + AddNode("s2", "Identity", {"in2", AsControlDependency("in1")}, &want); + AddNode("out1", "Add", {"s1", "s2"}, &want); + AddNode("out2", "Identity", {"s2"}, &want); + + CompareGraphs(want, got); +} + TEST_F(ConstantFoldingTest, NoOpReduction) { - // Build a simple graph with a reduction that can be reduced to the identity. + // Build a simple graph with a reduction that can be reduced to the + // identity. tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); Output v = ops::Variable(scope.WithOpName("v"), {3, 5, 7}, DT_FLOAT); @@ -1308,8 +1339,8 @@ TEST_F(ConstantFoldingTest, Packing) { TF_EXPECT_OK(status); // Make sure that the representation of the folded constant is space - // efficient: in particular, the whole message should be smaller than 8k (the - // size needed to naively encode 1000 floats folded twice). + // efficient: in particular, the whole message should be smaller than 8k + // (the size needed to naively encode 1000 floats folded twice). EXPECT_GT(8000, output.ByteSizeLong()); } @@ -1337,7 +1368,7 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) { GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */); + ConstantFolding fold(nullptr /* cpu_device */); GraphDef output; Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -1398,7 +1429,7 @@ TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch.push_back("reshape"); - ConstantFolding fold(RewriterConfig::AGGRESSIVE, nullptr /* cpu_device */); + ConstantFolding fold(nullptr /* cpu_device */); GraphDef output; Status status = fold.Optimize(nullptr, item, &output); TF_EXPECT_OK(status); @@ -1426,6 +1457,96 @@ TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { EXPECT_EQ(3, found); } +TEST_F(ConstantFoldingTest, LargeConstant) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + // Generate a 4k by 4k constant matrix. + Output mat_diag = + ops::Const(scope.WithOpName("mat_diag"), 3.14f, TensorShape({1024 * 4})); + Output mat = ops::Diag(scope.WithOpName("mat"), mat_diag); + Output out = ops::Identity(scope.WithOpName("out"), mat); + + GrapplerItem item; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + item.fetch.push_back("out"); + + ConstantFolding fold(nullptr /* cpu_device */); + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // Make sure the diag node hasn't been folded, since it would use too much + // memory to encode the corresponding constant. + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "out") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("mat", node.input(0)); + ++found; + } else if (node.name() == "mat") { + EXPECT_EQ("Diag", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("mat_diag", node.input(0)); + ++found; + } + } + EXPECT_EQ(2, found); + + EXPECT_GT(1024 * 1024, output.ByteSizeLong()); +} + +TEST_F(ConstantFoldingTest, SwitchIdenticalInputs) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output x = ops::Placeholder(s.WithOpName("x"), DT_BOOL, + ops::Placeholder::Shape(TensorShape({}))); + ops::Switch sw = ops::Switch(s.WithOpName("switch"), x, x); + Output id_false = ops::LogicalNot(s.WithOpName("id_false"), sw.output_false); + Output id_true = ops::LogicalNot(s.WithOpName("id_true"), sw.output_true); + + GrapplerItem item; + item.fetch.push_back("id_false"); + item.fetch.push_back("id_true"); + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + ConstantFolding fold(nullptr /* cpu_device */); + GraphDef output; + Status status = fold.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + EXPECT_EQ(6, output.node_size()); + int found = 0; + for (const auto& node : output.node()) { + if (node.name() == "switch" || node.name() == "x") { + ++found; + } + if (node.name() == "id_false") { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^ConstantFoldingCtrl/switch_0", node.input(0)); + ++found; + } + if (node.name() == "id_true") { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("^ConstantFoldingCtrl/switch_1", node.input(0)); + ++found; + } + if (node.name() == "ConstantFoldingCtrl/switch_0") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("switch", node.input(0)); + ++found; + } + if (node.name() == "ConstantFoldingCtrl/switch_1") { + EXPECT_EQ("Identity", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("switch:1", node.input(0)); + ++found; + } + } + EXPECT_EQ(6, found); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h similarity index 53% rename from tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc rename to tensorflow/core/grappler/optimizers/custom_graph_optimizer.h index 1d5b5c2c1e3bd27e6a6006aeb0c35f703e288e11..a80d46f416d8c1f43c46c3183f19e4e582dec8ec 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer.h @@ -13,28 +13,23 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" +#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_ +#define TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_ -#define EIGEN_USE_THREADS +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/lib/core/status.h" -#include "third_party/eigen3/Eigen/Core" +namespace tensorflow { +namespace grappler { -#ifdef TF_XLA_HAS_SSE4_1 +// A custom optimizer that can be registered. +class CustomGraphOptimizer : public GraphOptimizer { + public: + virtual ~CustomGraphOptimizer() {} + virtual Status Init() = 0; +}; -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( - xla::cpu::runtime::V4F32SSE x) { - Eigen::internal::Packet4f p = x; - return Eigen::internal::plog(p); -} +} // end namespace grappler +} // end namespace tensorflow -#endif // TF_XLA_HAS_SSE4_1 - -namespace xla { -namespace cpu { -namespace runtime { - -const char *const kLogV4F32SSESymbolName = "__xla_cpu_runtime_LogV4F32SSE"; - -} // namespace runtime -} // namespace cpu -} // namespace xla +#endif // TENSORFLOW_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..6eed43c2b132c02b58a0088c30dd5648fe80d212 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc @@ -0,0 +1,61 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +#include +#include + +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { +namespace grappler { + +namespace { +typedef std::unordered_map + RegistrationMap; +RegistrationMap* registered_optimizers = nullptr; +RegistrationMap* GetRegistrationMap() { + if (registered_optimizers == nullptr) + registered_optimizers = new RegistrationMap; + return registered_optimizers; +} +} // namespace + +std::unique_ptr +CustomGraphOptimizerRegistry::CreateByNameOrNull(const string& name) { + const auto it = GetRegistrationMap()->find(name); + if (it == GetRegistrationMap()->end()) return nullptr; + return std::unique_ptr(it->second()); +} + +std::vector CustomGraphOptimizerRegistry::GetRegisteredOptimizers() { + std::vector optimizer_names; + optimizer_names.reserve(GetRegistrationMap()->size()); + for (const auto& opt : *GetRegistrationMap()) + optimizer_names.emplace_back(opt.first); + return optimizer_names; +} + +void CustomGraphOptimizerRegistry::RegisterOptimizerOrDie( + const Creator& optimizer_creator, const string& name) { + const auto it = GetRegistrationMap()->find(name); + if (it != GetRegistrationMap()->end()) { + LOG(FATAL) << "CustomGraphOptimizer is registered twice: " << name; + } + GetRegistrationMap()->insert({name, optimizer_creator}); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..796da913737b9db1fe4e5cb00b235bf0f5979593 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_ + +#include +#include +#include +#include + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +class CustomGraphOptimizerRegistry { + public: + static std::unique_ptr CreateByNameOrNull( + const string& name); + + static std::vector GetRegisteredOptimizers(); + + typedef std::function Creator; + // Regsiter graph optimizer which can be called during program initialization. + // This class is not thread-safe. + static void RegisterOptimizerOrDie(const Creator& optimizer_creator, + const string& name); +}; + +class CustomGraphOptimizerRegistrar { + public: + explicit CustomGraphOptimizerRegistrar( + const CustomGraphOptimizerRegistry::Creator& creator, + const string& name) { + CustomGraphOptimizerRegistry::RegisterOptimizerOrDie(creator, name); + } +}; + +#define REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, name) \ + namespace { \ + static CustomGraphOptimizerRegistrar \ + MyCustomGraphOptimizerClass##_registrar( \ + []() { return new MyCustomGraphOptimizerClass; }, (name)); \ + } // namespace + +#define REGISTER_GRAPH_OPTIMIZER(MyCustomGraphOptimizerClass) \ + REGISTER_GRAPH_OPTIMIZER_AS(MyCustomGraphOptimizerClass, \ + #MyCustomGraphOptimizerClass) + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_CUSTOM_GRAPH_OPTIMIZER_REGISTRY_H_ diff --git a/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..629f5e83c12e91a7cc0f68dc9993e0f7c0117d3c --- /dev/null +++ b/tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry_test.cc @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" + +#include +#include +#include +#include + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +static const char* kTestOptimizerName = "Test"; + +class TestGraphOptimizer : public CustomGraphOptimizer { + public: + Status Init() override { return Status::OK(); } + string name() const override { return kTestOptimizerName; } + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override { + return Status::OK(); + } + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} +}; + +REGISTER_GRAPH_OPTIMIZER_AS(TestGraphOptimizer, "StaticRegister"); + +TEST(CustomGraphOptimizerRegistryTest, DynamicRegistration) { + std::vector optimizers = + CustomGraphOptimizerRegistry::GetRegisteredOptimizers(); + std::unique_ptr test_optimizer; + ASSERT_EQ( + 0, std::count(optimizers.begin(), optimizers.end(), "DynamicRegister")); + test_optimizer = + CustomGraphOptimizerRegistry::CreateByNameOrNull("DynamicRegister"); + EXPECT_EQ(nullptr, test_optimizer); + CustomGraphOptimizerRegistry::RegisterOptimizerOrDie( + []() { return new TestGraphOptimizer; }, "DynamicRegister"); + optimizers = CustomGraphOptimizerRegistry::GetRegisteredOptimizers(); + ASSERT_EQ( + 1, std::count(optimizers.begin(), optimizers.end(), "DynamicRegister")); + test_optimizer = + CustomGraphOptimizerRegistry::CreateByNameOrNull("DynamicRegister"); + ASSERT_NE(nullptr, test_optimizer); + EXPECT_EQ(kTestOptimizerName, test_optimizer->name()); +} + +TEST(CustomGraphOptimizerRegistryTest, StaticRegistration) { + const std::vector optimizers = + CustomGraphOptimizerRegistry::GetRegisteredOptimizers(); + EXPECT_EQ(1, + std::count(optimizers.begin(), optimizers.end(), "StaticRegister")); + std::unique_ptr test_optimizer = + CustomGraphOptimizerRegistry::CreateByNameOrNull("StaticRegister"); + ASSERT_NE(nullptr, test_optimizer); + EXPECT_EQ(kTestOptimizerName, test_optimizer->name()); +} + +TEST(GraphOptimizerRegistryTest, CrashesOnDuplicateRegistration) { + const auto creator = []() { return new TestGraphOptimizer; }; + EXPECT_DEATH(CustomGraphOptimizerRegistry::RegisterOptimizerOrDie( + creator, "StaticRegister"), + "twice"); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc b/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..1820af6844215475d2bfccba93891a52029218b2 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/gpu_swapping_kernels.cc @@ -0,0 +1,88 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Op kernels used to swap data in and out of GPU memory. + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace { + +class CopyFromGpuToHostKernel : public AsyncOpKernel { + public: + explicit CopyFromGpuToHostKernel(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor& input = ctx->input(0); + OP_REQUIRES_ASYNC( + ctx, !ctx->input_alloc_attr(0).on_host(), + errors::Internal("The input tensor to the _CopyFromGpuToHost kernel " + "must reside on the device."), + done); + + AllocatorAttributes alloc_attrs; + alloc_attrs.set_gpu_compatible(true); + alloc_attrs.set_on_host(true); + Tensor* output; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->allocate_output(0, input.shape(), &output, alloc_attrs), + done); + + ctx->op_device_context()->CopyDeviceTensorToCPU( + &input, "CopyFromGpuToHost", static_cast(ctx->device()), + output, [ctx, done](const Status& s) { + ctx->SetStatus(s); + done(); + }); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("_CopyFromGpuToHost").Device(DEVICE_GPU).HostMemory("output"), + CopyFromGpuToHostKernel); + +class CopyFromHostToGpuKernel : public AsyncOpKernel { + public: + explicit CopyFromHostToGpuKernel(OpKernelConstruction* context) + : AsyncOpKernel(context) {} + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor& input = ctx->input(0); + OP_REQUIRES_ASYNC( + ctx, ctx->input_alloc_attr(0).on_host(), + errors::Internal("The input tensor to the _CopyFromHostToGpu kernel " + "must reside on the host."), + done); + + Tensor* output; + OP_REQUIRES_OK_ASYNC(ctx, ctx->allocate_output(0, input.shape(), &output), + done); + + ctx->op_device_context()->CopyCPUTensorToDevice( + &input, static_cast(ctx->device()), output, + [ctx, done](const Status& s) { + ctx->SetStatus(s); + done(); + }); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("_CopyFromHostToGpu").Device(DEVICE_GPU).HostMemory("input"), + CopyFromHostToGpuKernel); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/gpu_swapping_ops.cc b/tensorflow/core/grappler/optimizers/gpu_swapping_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..46828346da608a237528da2a2a8070c57946f762 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/gpu_swapping_ops.cc @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Definition for the ops used to swap data in and out of GPU memory. + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace { + +// The _CopyFromGpuToHost op copies its input tensor to the host. The input must +// reside on GPU. The op itself must be placed on GPU. +REGISTER_OP("_CopyFromGpuToHost") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data != nullptr) { + c->set_output_handle_shapes_and_types(0, *handle_data); + } + return Status::OK(); + }) + .Doc("Copies the input tensor from gpu to the host."); + +// The _CopyFromHostToGpu op copies its input tensor from the host to the GPU. +// The input must reside on CPU. The op itself must be placed on GPU. +REGISTER_OP("_CopyFromHostToGpu") + .Input("input: T") + .Output("output: T") + .Attr("T: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + auto* handle_data = c->input_handle_shapes_and_types(0); + if (handle_data != nullptr) { + c->set_output_handle_shapes_and_types(0, *handle_data); + } + return Status::OK(); + }) + .Doc("Copies the input tensor from the host to the GPU."); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc index 4342179176b10fbafbe1623c012ea8913212b8f6..826f00209b15705f2a9b8b43f78134498a19d167 100644 --- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc @@ -1717,13 +1717,28 @@ class SqueezeProcessor : public AgnosticNodeProcessor { protected: bool ShouldProcess() const override { - return !MustPreserve() && IsPortZeroDimsN(*node_, 2) && HasOutputs() && - IsNodeAfterNCHWToNHWC() && IsInputConvertible() && IsAlongDimHW() && - IsOnGPU(); + bool is_dims_supported = (IsPortZeroDimsN(*node_, 2) && IsAlongHW()) || + (IsPortZeroDimsN(*node_, 1) && IsAlongNHW()); + return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() && + IsInputConvertible() && is_dims_supported && IsOnGPU(); } Status AddLayoutTransposeToOutputs() override { return Status::OK(); } + Status CustomizedProcessing() override { + TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims")); + auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list(); + if (list->i_size() == 2) { + list->set_i(0, 2); + list->set_i(1, 3); + } else if (list->i_size() == 3) { + list->set_i(1, 2); + list->set_i(2, 3); + } + return Status::OK(); + } + + private: bool IsInputConvertible() const { int input_port; auto input = node_map_->GetNode(node_->input(0)); @@ -1736,33 +1751,31 @@ class SqueezeProcessor : public AgnosticNodeProcessor { if (shape.dim(1).size() == 1 && shape.dim(2).size() == 1) { return true; } + if (shape.dim(0).size() == 1 && shape.dim(1).size() == 1 && + shape.dim(2).size() == 1) { + return true; + } } return false; } - bool IsAlongDimHW() const { + bool IsAlongAxis(const std::vector& axis) const { if (node_->attr().find("squeeze_dims") != node_->attr().end()) { auto list = node_->attr().at("squeeze_dims").list(); // If list is empty, Squeeze op will squeeze all dimensions of size 1. if (list.i_size() == 0) return true; - if (list.i_size() == 2) { - if (list.i(0) == 1 && list.i(1) == 2) { - return true; + if (list.i_size() == axis.size()) { + bool along_axis = true; + for (int i = 0; i < axis.size(); i++) { + along_axis = along_axis && (list.i(i) == axis[i]); } + if (along_axis) return true; } } return false; } - - Status CustomizedProcessing() override { - TF_RETURN_IF_ERROR(HasAttribute(*node_, "squeeze_dims")); - auto list = node_->mutable_attr()->at("squeeze_dims").mutable_list(); - if (list->i_size() == 2) { - list->set_i(0, 2); - list->set_i(1, 3); - } - return Status::OK(); - } + bool IsAlongHW() const { return IsAlongAxis({1, 2}); } + bool IsAlongNHW() const { return IsAlongAxis({0, 1, 2}); } }; class ReduceProcessor : public AgnosticNodeProcessor { @@ -1781,7 +1794,7 @@ class ReduceProcessor : public AgnosticNodeProcessor { } Status CustomizedProcessing() override { - if (IsAlongNHW() || IsAlongHW() || IsAlongC()) { + if (IsReduceAxisSupported()) { DataType dtype = node_->attr().at("Tidx").type(); TF_RETURN_IF_ERROR( UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype)); @@ -1790,17 +1803,17 @@ class ReduceProcessor : public AgnosticNodeProcessor { } Status AddLayoutTransposeToOutputs() override { - if ((IsAlongNHW() || IsAlongHW() || IsAlongC()) && KeepDims()) { - return AgnosticNodeProcessor::AddLayoutTransposeToOutputs(); - } else { - return Status::OK(); + if (KeepDims()) { + return AddTransformToOutputs("Transpose"); } + return Status::OK(); } private: bool IsReduceAxisSupported() const { - return IsAlongAllFourDims() || IsAlongHWC() || - IsAlongNHW() || IsAlongHW() || IsAlongC(); + return KeepDims() || ((IsAlongAllFourDims() || IsAlongHWC() || + IsAlongNHW() || IsAlongHW() || IsAlongC()) && + !KeepDims()); } bool IsAlongAxis(const std::vector& axis) const { diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc new file mode 100644 index 0000000000000000000000000000000000000000..102526e22f4742cb90757a1daf55467dd16afc3e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/loop_optimizer.h" + +#include +#include + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace grappler { + +Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + *optimized_graph = item.graph; + + return Status::OK(); +} + +void LoopOptimizer::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/, + const GraphDef& /*optimized_graph*/, + double /*result*/) { + // Nothing to do for LoopOptimizer. +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.h b/tensorflow/core/grappler/optimizers/loop_optimizer.h new file mode 100644 index 0000000000000000000000000000000000000000..106d4628ae68f3c92ab597f903f96a6af8a64b8d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/loop_optimizer.h @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_ + +#include +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +class LoopOptimizer : public GraphOptimizer { + public: + LoopOptimizer() : opt_level_(RewriterConfig::ON) {} + explicit LoopOptimizer(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level) {} + ~LoopOptimizer() override {} + + string name() const override { return "loop_optimizer"; }; + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override; + + private: + RewriterConfig::Toggle opt_level_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_LOOP_OPTIMIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c09434f60916b9bf269b0f5006b8a3732afaa5fc --- /dev/null +++ b/tensorflow/core/grappler/optimizers/loop_optimizer_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/loop_optimizer.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class LoopOptimizerTest : public ::testing::Test {}; + +void VerifyGraphsEqual(const GraphDef& original_graph, + const GraphDef& optimized_graph, const string& func) { + EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << func; + for (int i = 0; i < original_graph.node_size(); ++i) { + const NodeDef& original = original_graph.node(i); + const NodeDef& optimized = optimized_graph.node(i); + EXPECT_EQ(original.name(), optimized.name()) << func; + EXPECT_EQ(original.op(), optimized.op()) << func; + EXPECT_EQ(original.input_size(), optimized.input_size()) << func; + for (int j = 0; j < original.input_size(); ++j) { + EXPECT_EQ(original.input(j), optimized.input(j)) << func; + } + } +} + +TEST_F(LoopOptimizerTest, NoOp) { + // This trivial graph is so basic there's nothing to optimize. + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + LoopOptimizer optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + VerifyGraphsEqual(item.graph, output, __FUNCTION__); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 9f3e94052f8289f575abdc6326cf2ee31749a8e6..694139fa5033410375fcfae2f1141c82fa9d550c 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -490,12 +490,12 @@ void RecomputationRewritingPass(RewriterConfig::MemOptType optimization_level, } bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { - // Look for AddN nodes and record input names. + // Look for AddN nodes (and equivalent) and record input names. GraphView view(&item->graph); std::unordered_map> addn_list; for (NodeDef& node : *item->graph.mutable_node()) { - if (!IsAddN(node)) { + if (!IsAddN(node) && node.op() != "AccumulateNV2") { continue; } // There is nothing to gain by optimizing nodes with 2 or fewer inputs. @@ -511,6 +511,10 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { } } + if (addn_list.empty()) { + return false; + } + GraphMemory memory(*item); const std::unordered_map& devices = cluster->GetDevices(); @@ -560,6 +564,13 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { VLOG(1) << "Missing properties for " << node->name(); continue; } + const TensorShapeProto& shape = + properties.GetOutputProperties(node->name())[0].shape(); + PartialTensorShape shp(shape); + if (!shp.IsFullyDefined()) { + VLOG(1) << "Shape not fully known for " << node->name(); + continue; + } // Compute a topological ordering for the node fanin. std::unordered_map topo_order; @@ -608,8 +619,6 @@ bool SchedulingPass(Cluster* cluster, GrapplerItem* item) { } } - const TensorShapeProto& shape = - properties.GetOutputProperties(node->name())[0].shape(); DataType dtype = node->attr().at("T").type(); const string& device = node->device(); @@ -711,17 +720,19 @@ Status BuildSwapPair(NodeDef* node, int input_to_swap, // Force the tensor to be copied to cpu. NodeDef* swap_out_node = graph->add_node(); swap_out_node->set_name(swap_out_name); - swap_out_node->set_op("Identity"); - swap_out_node->set_device("/device:CPU:0"); + swap_out_node->set_op("_CopyFromGpuToHost"); // Force the tensor to be restored to the device. NodeDef* swap_in_node = graph->add_node(); swap_in_node->set_name(swap_in_name); - swap_in_node->set_op("Identity"); + swap_in_node->set_op("_CopyFromHostToGpu"); *swap_in_node->add_input() = swap_out_node->name(); - // Colocate the swap_in_ node with the node itself. + // Colocate the swap_out_ and swap_in_ nodes with the node itself. + swap_out_node->set_device(node->device()); + swap_in_node->set_device(node->device()); string coloc_group = strings::StrCat("loc@", tensor_to_swap); + (*swap_out_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); (*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); (*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group); @@ -1094,7 +1105,8 @@ bool SwappingPass(RewriterConfig::MemOptType optimization_level, Cluster* cluster, GrapplerItem* item, std::unordered_set* skip_list) { std::unordered_map nodes_to_swap; - if (optimization_level == RewriterConfig::SWAPPING_HEURISTICS || + if (optimization_level == RewriterConfig::DEFAULT_MEM_OPT || + optimization_level == RewriterConfig::SWAPPING_HEURISTICS || optimization_level == RewriterConfig::HEURISTICS) { // Use heuristics to figure out what needs to be swapped; IdentifySwappingCandidates(cluster, item, skip_list, &nodes_to_swap); @@ -1223,13 +1235,15 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, bool updated_graph = true; for (int i = 0; i < 25 && updated_graph; ++i) { updated_graph = false; - if ((optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS || + if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT || + optimization_level_ == RewriterConfig::SCHEDULING_HEURISTICS || optimization_level_ == RewriterConfig::HEURISTICS) && cluster != nullptr) { updated_graph |= SchedulingPass(cluster, &optimized_item); } - if ((optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS || + if ((optimization_level_ == RewriterConfig::DEFAULT_MEM_OPT || + optimization_level_ == RewriterConfig::SWAPPING_HEURISTICS || optimization_level_ == RewriterConfig::HEURISTICS || optimization_level_ == RewriterConfig::MANUAL) && cluster != nullptr) { diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc index 5d7913e0c018ecf14cc09ab91d3a71125c720aa5..9595936e9e6158045a13ebede95d63b9291ca434 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer_test.cc @@ -221,16 +221,20 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) { // Build a simple graph with an op that's marked for swapping. tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output a = ops::Variable(s.WithOpName("a"), {10, 10}, DT_FLOAT); - Output b = ops::AddN(s.WithOpName("b"), {a}); - Output c = ops::AddN(s.WithOpName("c"), {b}); - Output d = ops::AddN(s.WithOpName("d"), {c}); - Output e = ops::AddN(s.WithOpName("e"), {b, d}); + Output a = + ops::Variable(s.WithOpName("a").WithDevice("/gpu:0"), {10, 10}, DT_FLOAT); + Output b = ops::AddN(s.WithOpName("b").WithDevice("/gpu:0"), {a}); + Output c = ops::AddN(s.WithOpName("c").WithDevice("/gpu:0"), {b}); + Output d = ops::AddN(s.WithOpName("d").WithDevice("/gpu:0"), {c}); + Output e = ops::AddN(s.WithOpName("e").WithDevice("/gpu:0"), {b, d}); + + Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {10, 10}); + Output init = ops::Assign(s.WithOpName("init"), a, constant); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - EXPECT_EQ(5, item.graph.node_size()); + EXPECT_EQ(7, item.graph.node_size()); EXPECT_EQ(NodeName(e.name()), item.graph.node(4).name()); AttrValue& val = (*item.graph.mutable_node(4)->mutable_attr())["_swap_to_host"]; @@ -243,32 +247,43 @@ TEST_F(MemoryOptimizerTest, SimpleSwapping) { Status status = optimizer.Optimize(cluster.get(), item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(7, output.node_size()); - const NodeDef& new_e = output.node(4); + EXPECT_EQ(9, output.node_size()); + const NodeDef& new_e = output.node(6); EXPECT_EQ(NodeName(e.name()), new_e.name()); EXPECT_EQ(2, new_e.input_size()); EXPECT_EQ(NodeName(d.name()), new_e.input(1)); EXPECT_EQ("swap_in_e_0", new_e.input(0)); - const NodeDef& swap_out = output.node(5); + const NodeDef& swap_out = output.node(7); EXPECT_EQ("swap_out_e_0", swap_out.name()); + EXPECT_EQ("_CopyFromGpuToHost", swap_out.op()); - const NodeDef& swap_in = output.node(6); + const NodeDef& swap_in = output.node(8); EXPECT_EQ("swap_in_e_0", swap_in.name()); + EXPECT_EQ("_CopyFromHostToGpu", swap_in.op()); EXPECT_EQ(NodeName(b.name()), swap_out.input(0)); EXPECT_EQ(NodeName(swap_out.name()), swap_in.input(0)); EXPECT_EQ("^c", swap_in.input(1)); - const NodeDef& new_c = output.node(2); + const NodeDef& new_c = output.node(4); EXPECT_EQ(NodeName(c.name()), new_c.name()); EXPECT_EQ("^swap_out_e_0", new_c.input(1)); // Run the optimizer a second time to ensure it's idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(cluster.get(), item, &output); + GrapplerItem item_copy(item, std::move(output)); + status = optimizer.Optimize(cluster.get(), item_copy, &output); TF_EXPECT_OK(status); + +#if GOOGLE_CUDA + item.fetch = {"e"}; + item.init_ops = {init.name()}; + auto tensors_expected = EvaluateFetchNodes(item); + GrapplerItem optimized(item, std::move(output)); + auto tensors = EvaluateFetchNodes(optimized); + test::ExpectTensorEqual(tensors_expected[0], tensors[0]); +#endif } TEST_F(MemoryOptimizerTest, SwappingHeuristics) { @@ -287,9 +302,13 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) { Output h = ops::Exp(s.WithOpName("h").WithDevice("/gpu:0"), c); Output i = ops::Log(s.WithOpName("i").WithDevice("/gpu:0"), d); + Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {128, 128, 8}); + Output init = ops::Assign(s.WithOpName("init"), v, constant); + GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch = {"e", "f", "g", "h", "i"}; + item.init_ops = {init.name()}; std::unique_ptr cluster(CreateVirtualCluster()); @@ -308,6 +327,15 @@ TEST_F(MemoryOptimizerTest, SwappingHeuristics) { EXPECT_EQ("axis", node.input(4)); } } + +#if GOOGLE_CUDA + auto tensors_expected = EvaluateFetchNodes(item); + GrapplerItem optimized(item, std::move(output)); + auto tensors = EvaluateFetchNodes(optimized); + for (int i = 0; i < item.fetch.size(); ++i) { + test::ExpectTensorEqual(tensors_expected[i], tensors[i]); + } +#endif } TEST_F(MemoryOptimizerTest, UnswappableInputs) { @@ -325,9 +353,13 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) { Output e = ops::Concat(s.WithOpName("e").WithDevice("/gpu:0"), {b, c, d}, axis); + Output constant = ops::Const(s.WithOpName("constant"), 0.0f, {128, 128, 8}); + Output init = ops::Assign(s.WithOpName("init"), v, constant); + GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); item.fetch = {"e"}; + item.init_ops = {init.name()}; std::unique_ptr cluster(CreateVirtualCluster()); @@ -344,6 +376,13 @@ TEST_F(MemoryOptimizerTest, UnswappableInputs) { EXPECT_EQ("^swap_out_d_2", node.input(4)); } } + +#if GOOGLE_CUDA + auto tensors_expected = EvaluateFetchNodes(item); + GrapplerItem optimized(item, std::move(output)); + auto tensors = EvaluateFetchNodes(optimized); + test::ExpectTensorEqual(tensors_expected[0], tensors[0]); +#endif } TEST_F(MemoryOptimizerTest, AccumulationRewrites) { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 6d93f741861d6b1cb3731007d7be6b1f5a598c61..7ae77207afc5be86a99bf8145025e0f18ef4af0f 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -19,9 +19,11 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" #include "tensorflow/core/grappler/optimizers/auto_parallel.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" +#include "tensorflow/core/grappler/optimizers/loop_optimizer.h" #include "tensorflow/core/grappler/optimizers/memory_optimizer.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/grappler/utils/topological_sort.h" @@ -75,6 +77,9 @@ std::unique_ptr MetaOptimizer::NewOptimizer( graph_optimizer.reset( new DependencyOptimizer(cfg_.dependency_optimization())); } + if (optimizer == "loop") { + graph_optimizer.reset(new LoopOptimizer(cfg_.loop_optimization())); + } return graph_optimizer; } @@ -97,11 +102,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, optimizers.push_back(std::unique_ptr( new DependencyOptimizer(cfg_.dependency_optimization()))); } + if (cfg_.loop_optimization() != RewriterConfig::OFF) { + optimizers.push_back(std::unique_ptr( + new LoopOptimizer(cfg_.loop_optimization()))); + } if (cfg_.layout_optimizer() != RewriterConfig::OFF) { optimizers.push_back( std::unique_ptr(new LayoutOptimizer())); } - if (cfg_.memory_optimization() > 1) { + if (cfg_.memory_optimization() != RewriterConfig::NO_MEM_OPT) { if (cfg_.memory_optimizer_target_node_name_prefix().empty()) { optimizers.push_back(std::unique_ptr( // Use the default target node name prefix "gradients/" @@ -118,14 +127,26 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, new AutoParallel(cfg_.auto_parallel().num_replicas()))); } } else { - std::set available_optimizers = { - "pruning", "constfold", "layout", "memory", - "autoparallel", "arithmetic", "dependency"}; - for (const auto& optimizer : cfg_.optimizers()) { - if (available_optimizers.find(optimizer) != available_optimizers.end()) { - optimizers.push_back(NewOptimizer(optimizer)); + const std::set available_optimizers = { + "pruning", "constfold", "layout", "memory", + "autoparallel", "arithmetic", "dependency", "loop"}; + std::vector custom_optimizer_names; + for (const auto& optimizer_name : cfg_.optimizers()) { + if (available_optimizers.find(optimizer_name) != + available_optimizers.end()) { + optimizers.push_back(NewOptimizer(optimizer_name)); + } else { + custom_optimizer_names.push_back(optimizer_name); } } + // Now run the custom optimizers. + for (const auto& optimizer_name : custom_optimizer_names) { + std::unique_ptr opt = + CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name); + if (opt == nullptr) continue; + TF_RETURN_IF_ERROR(opt->Init()); + optimizers.push_back(std::move(opt)); + } } if (optimizers.empty()) { @@ -136,7 +157,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, bool already_optimized = false; for (const auto& optimizer : optimizers) { if (!already_optimized) { - auto status = optimizer->Optimize(cluster, item, optimized_graph); + Status status = optimizer->Optimize(cluster, item, optimized_graph); string result; if (!status.ok()) { VLOG(1) << "Not able to apply optimizer " << optimizer->name() @@ -152,7 +173,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, << " return status: " << result; } else { GrapplerItem optimized_item(item, std::move(*optimized_graph)); - auto status = + Status status = optimizer->Optimize(cluster, optimized_item, optimized_graph); string result; if (!status.ok()) { @@ -204,8 +225,10 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) { cfg.layout_optimizer() != RewriterConfig::OFF || cfg.constant_folding() != RewriterConfig::OFF || cfg.dependency_optimization() != RewriterConfig::OFF || + cfg.loop_optimization() == RewriterConfig::ON || cfg.arithmetic_optimization() != RewriterConfig::OFF || - cfg.auto_parallel().enable() || cfg.memory_optimization() > 1 || + cfg.auto_parallel().enable() || + cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT || !cfg.optimizers().empty(); } diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..536347d8348738e1755e920f3f08c2d4858cb256 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/meta_optimizer.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class TestOptimizer : public CustomGraphOptimizer { + public: + static void SetOptimized(const bool flag_value) { optimized_ = flag_value; } + static bool IsOptimized() { return optimized_; } + + TestOptimizer() {} + string name() const override { return "test_optimizer"; } + + Status Init() override { return Status::OK(); } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override { + optimized_ = true; + *optimized_graph = item.graph; + return Status::OK(); + } + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} + + private: + static bool optimized_; +}; + +bool TestOptimizer::optimized_; + +REGISTER_GRAPH_OPTIMIZER(TestOptimizer); + +TEST(MetaOptimizerTest, RunsCustomOptimizer) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TestOptimizer::SetOptimized(false); + RewriterConfig rewriter_config; + rewriter_config.add_optimizers("TestOptimizer"); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestOptimizer::IsOptimized()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/model_pruner.cc b/tensorflow/core/grappler/optimizers/model_pruner.cc index ece9df012e5bd12b351c45881d75f54a46c4d459..3311e970108d94d34a92842d51aca8f0c99d904c 100644 --- a/tensorflow/core/grappler/optimizers/model_pruner.cc +++ b/tensorflow/core/grappler/optimizers/model_pruner.cc @@ -50,7 +50,7 @@ bool IsTrivialOp(const NodeDef& node, const GraphRewriter& rewriter) { Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* pruned_graph) { - const std::unordered_set& nodes_to_preserve = item.NodesToPreserve(); + const std::unordered_set nodes_to_preserve = item.NodesToPreserve(); // Prune all the nodes that won't be executed, ie all the nodes that aren't in // the fanin of a fetch node. If fetch nodes aren't specified, we'll assume @@ -59,6 +59,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, if (!nodes_to_preserve.empty()) { std::vector terminal_nodes(nodes_to_preserve.begin(), nodes_to_preserve.end()); + std::sort(terminal_nodes.begin(), terminal_nodes.end()); bool ill_formed = false; std::vector keep = ComputeTransitiveFanin(item.graph, terminal_nodes, &ill_formed); @@ -67,7 +68,7 @@ Status ModelPruner::Optimize(Cluster* cluster, const GrapplerItem& item, // let's be conservative and preserve the graph as is. return errors::InvalidArgument("Invalid input graph."); } - // Try to keep the nodes ordored somewhat topologically since this helps + // Try to keep the nodes ordered somewhat topologically since this helps // further optimizations perform better. for (int i = keep.size() - 1; i >= 0; --i) { *runnable_item.graph.add_node() = *keep[i]; diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index eb5a2c48dc8b12f7b4090e80c403e238a526e122..81bb5e6c3b26ebbed8cd1555c10d2dd6f2a47c12 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -29,6 +29,18 @@ limitations under the License. namespace tensorflow { namespace grappler { +namespace { +template +bool SafeSetScalarTensorValue(double value, Tensor* tensor) { + using RealType = typename Eigen::NumTraits::Real; + if (value > std::numeric_limits::max() || + value < std::numeric_limits::min()) { + return false; + } + tensor->flat()(0) = static_cast(value); + return true; +} +} // namespace NodeMap::NodeMap(GraphDef* graph) { CHECK(graph != nullptr); @@ -402,5 +414,43 @@ string SimpleGraphView::PrintToString() const { return str; } +#define HANDLE_CASE(DTYPE) \ + case DTYPE: \ + if (!SafeSetScalarTensorValue::Type>( \ + static_cast(value), tensor)) { \ + return errors::InvalidArgument("Cannot store value ", value, \ + " in tensor of type " #DTYPE); \ + } \ + break + +Status SetTensorValue(DataType dtype, int value, Tensor* tensor) { + // TODO(rmlarsen): Support more general shapes. + if (tensor->NumElements() != 1) { + return errors::InvalidArgument( + "Expected scalar tensor, got num_elements = ", tensor->NumElements()); + } + switch (dtype) { + // TODO(rmlarsen): Handle DT_HALF. + // HANDLE_CASE(DT_HALF); + HANDLE_CASE(DT_BOOL); + HANDLE_CASE(DT_FLOAT); + HANDLE_CASE(DT_DOUBLE); + HANDLE_CASE(DT_UINT8); + HANDLE_CASE(DT_INT8); + HANDLE_CASE(DT_UINT16); + HANDLE_CASE(DT_INT16); + HANDLE_CASE(DT_INT32); + HANDLE_CASE(DT_INT64); + HANDLE_CASE(DT_COMPLEX64); + HANDLE_CASE(DT_COMPLEX128); + default: + return errors::InvalidArgument("Unsupported type ", + DataTypeString(dtype)); + } + return Status::OK(); +} + +#undef HANDLE_CASE + } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 4ecb28f681507f50ad5909f15cf1b408ed6e2979..255319693a57a7cc493365a51d5d04d2893f08c5 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -167,6 +168,8 @@ NodeDef* GetTailOfChain(const NodeDef& source, const NodeMap& node_map, void PermuteNodesInPlace(GraphDef* graph, std::vector* permutation, bool invert_permutation); +Status SetTensorValue(DataType dtype, int value, Tensor* tensor); + class SimpleGraphView { public: Status Initialize(const GraphDef& graph) { diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index 0a9dbe22cfe3cd01c2c61661adcdd4839a957f03..5d32609434d0b0363b651604ff6fccb151723dd6 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -142,6 +142,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", + "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", ], ) diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc index 813f65f825759ca22dba2bdfd8433d946b7dd852..fef8e97b6e3a9102cec10262cf8afbd64ee424af 100644 --- a/tensorflow/core/grappler/utils/grappler_test.cc +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -35,5 +35,60 @@ std::vector GrapplerTest::EvaluateNodes( return output_tensors; } +std::vector GrapplerTest::EvaluateFetchNodes(const GrapplerItem& item) { + SessionOptions options; + std::unique_ptr session(NewSession(options)); + TF_CHECK_OK(session->Create(item.graph)); + RunOptions run_options; + if (!item.init_ops.empty()) { + std::vector dummy; + TF_CHECK_OK( + session->Run(run_options, {}, {}, item.init_ops, &dummy, nullptr)); + } + std::vector output_tensors; + TF_CHECK_OK( + session->Run(run_options, {}, item.fetch, {}, &output_tensors, nullptr)); + TF_CHECK_OK(session->Close()); + return output_tensors; +} + +void GrapplerTest::AddNode(const string& name, const string& op, + const std::vector& inputs, GraphDef* graph) { + auto* node = graph->add_node(); + node->set_name(name); + node->set_op(op); + for (const auto& input : inputs) { + node->add_input(input); + } +} + +void GrapplerTest::CompareGraphs(GraphDef want, GraphDef got) { + auto comparator = [](const NodeDef& n1, const NodeDef& n2) -> bool { + return n1.name() < n2.name(); + }; + std::sort(want.mutable_node()->begin(), want.mutable_node()->end(), + comparator); + std::sort(got.mutable_node()->begin(), got.mutable_node()->end(), comparator); + + for (int i = 0; i < want.node_size(); ++i) { + std::sort(want.mutable_node(i)->mutable_input()->begin(), + want.mutable_node(i)->mutable_input()->end()); + } + for (int i = 0; i < got.node_size(); ++i) { + std::sort(got.mutable_node(i)->mutable_input()->begin(), + got.mutable_node(i)->mutable_input()->end()); + } + + ASSERT_EQ(want.node_size(), got.node_size()); + for (int i = 0; i < want.node_size(); ++i) { + EXPECT_EQ(want.node(i).op(), got.node(i).op()); + EXPECT_EQ(want.node(i).name(), got.node(i).name()); + ASSERT_EQ(want.node(i).input_size(), got.node(i).input_size()); + for (int j = 0; j < want.node(i).input_size(); ++j) { + EXPECT_TRUE(IsSameInput(want.node(i).input(j), got.node(i).input(j))); + } + } +} + } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/utils/grappler_test.h b/tensorflow/core/grappler/utils/grappler_test.h index 46ce47c8c3b6bc18b6eac76bbdb8ec1f8a58fab2..fd6809b6e21b87bf5420898def17ea17bc0b427b 100644 --- a/tensorflow/core/grappler/utils/grappler_test.h +++ b/tensorflow/core/grappler/utils/grappler_test.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -29,6 +30,13 @@ class GrapplerTest : public ::testing::Test { protected: std::vector EvaluateNodes(const GraphDef& graph, const std::vector& node_names); + + std::vector EvaluateFetchNodes(const GrapplerItem& item); + + void AddNode(const string& name, const string& op, + const std::vector& inputs, GraphDef* graph); + + void CompareGraphs(GraphDef want, GraphDef got); }; } // end namespace grappler diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 523e3956996de2f1cd5a5626b15dfff73022a9d5..78786de16bf16fcbc156a46751f721d2c3e664ff 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -56,8 +56,8 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") config_setting( # Add "--define tensorflow_xsmm=1" to your build command to use libxsmm for - # convolutions (and possibly more in the future). You will also need - # appropriate -mavx*, as required by specific op you use. + # sparse matrix multiplications. You will also need appropriate -mavx* + # options, as required by specific op you use. name = "xsmm", values = { "define": "tensorflow_xsmm=1", @@ -65,12 +65,23 @@ config_setting( ) config_setting( - # Add "--define tensorflow_xsmm_backward=1" to your build command to use - # libxsmm for backward convolutions (and possibly more in the future). You - # will also need appropriate -mavx*, as required by specific op you use. - name = "xsmm_backward", + # Add "--define tensorflow_xsmm_convolutions=1" to your build command to + # use libxsmm for forward convolutions. You will also need appropriate + # -mavx* # options, as required by specific op you use. + name = "xsmm_convolutions", values = { - "define": "tensorflow_xsmm_backward=1", + "define": "tensorflow_xsmm_convolutions=1", + }, +) + +config_setting( + # Add "--define tensorflow_xsmm_convolutions=1 --define + # tensorflow_xsmm_backward_convolutions=1" to your build command to use libxsmm for + # backward convolutions (and possibly more in the future). You will also + # need appropriate -mavx* options, as required by specific op you use. + name = "xsmm_backward_convolutions", + values = { + "define": "tensorflow_xsmm_backward_convolutions=1", }, ) @@ -987,6 +998,7 @@ tf_cuda_cc_test( name = "constant_op_test", size = "small", srcs = ["constant_op_test.cc"], + tags = ["no_cuda_on_cpu_tap"], deps = [ ":constant_op", ":ops_testutil", @@ -1016,7 +1028,7 @@ tf_cc_test( name = "xsmm_conv2d_test", size = "small", srcs = select({ - ":xsmm": ["xsmm_conv2d_test.cc"], + ":xsmm_convolutions": ["xsmm_conv2d_test.cc"], "//conditions:default": [], }), deps = [ @@ -1031,7 +1043,7 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", ] + select({ - ":xsmm": [ + ":xsmm_convolutions": [ "@libxsmm_archive//:xsmm_avx", ], "//conditions:default": [], @@ -1040,7 +1052,7 @@ tf_cc_test( tf_cc_test( name = "conv_ops_test", - size = "small", + size = "medium", srcs = ["conv_ops_test.cc"], deps = [ ":conv_ops", @@ -1890,9 +1902,9 @@ tf_kernel_library( srcs = ["resource_variable_ops.cc"], deps = [ ":bounds_check", - ":critical_section", ":dense_update_functor", ":gather_functor", + ":mutex_ops", ":scatter_functor", ":state", ":training_op_helpers", @@ -3137,7 +3149,7 @@ tf_kernel_library( "conv_grad_ops_3d.cc", "deep_conv2d.cc", ] + select({ - ":xsmm": ["xsmm_conv2d.cc"], + ":xsmm_convolutions": ["xsmm_conv2d.cc"], "//conditions:default": [], }), hdrs = [ @@ -3147,7 +3159,7 @@ tf_kernel_library( "gemm_functors.h", "winograd_transform.h", ] + select({ - ":xsmm": ["xsmm_conv2d.h"], + ":xsmm_convolutions": ["xsmm_conv2d.h"], "//conditions:default": [], }), # Override EIGEN_STRONG_INLINE to inline when --define=override_eigen_strong_inline=true, @@ -3155,13 +3167,15 @@ tf_kernel_library( # on Windows. See https://github.com/tensorflow/tensorflow/issues/10521 copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]), defines = select({ - ":xsmm": [ - "TENSORFLOW_USE_LIBXSMM", - "EIGEN_USE_LIBXSMM", + ":xsmm_convolutions": [ + "TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS", ], "//conditions:default": [], }) + select({ - ":xsmm_backward": ["TENSORFLOW_USE_LIBXSMM_BACKWARD"], + ":xsmm": ["EIGEN_USE_LIBXSMM"], + "//conditions:default": [], + }) + select({ + ":xsmm_backward_convolutions": ["TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS"], "//conditions:default": [], }), prefix = "conv_ops", @@ -3178,7 +3192,7 @@ tf_kernel_library( "//tensorflow/core:lib_internal", "//tensorflow/core:nn_ops_op_lib", ] + select({ - ":xsmm": [ + ":xsmm_convolutions": [ "@libxsmm_archive//:xsmm_avx", ], "//conditions:default": [], @@ -4093,9 +4107,9 @@ tf_kernel_library( ) tf_kernel_library( - name = "critical_section", - prefix = "critical_section", - deps = STATE_DEPS + [":captured_function"], + name = "mutex_ops", + prefix = "mutex_ops", + deps = STATE_DEPS + [":ops_util"], ) tf_cc_test( @@ -4867,7 +4881,7 @@ filegroup( "winograd_transform.h", ":android_extended_ops_headers", ] + select({ - ":xsmm": [ + ":xsmm_convolutions": [ "xsmm_conv2d.h", "xsmm_conv2d.cc", ], @@ -5047,7 +5061,7 @@ filegroup( # Excluded due to experimental status: "debug_ops.*", "scatter_nd_op*", - "critical_section.*", + "mutex_ops.*", "batch_kernels.*", ], ), @@ -5114,7 +5128,6 @@ tf_kernel_library( srcs = [ "dequantize_op.cc", "meta_support.cc", - "quantization_utils.cc", "quantize_down_and_shrink_range.cc", "quantize_op.cc", "quantized_activation_ops.cc", @@ -5135,7 +5148,6 @@ tf_kernel_library( ], hdrs = [ "meta_support.h", - "quantization_utils.h", "reference_gemm.h", ], deps = [ @@ -5146,6 +5158,7 @@ tf_kernel_library( ":image_resizer_state", ":ops_util", ":pooling_ops", + ":quantization_utils", "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", @@ -5692,6 +5705,16 @@ tf_kernel_library( ], ) +cc_library( + name = "quantization_utils", + srcs = ["quantization_utils.cc"], + hdrs = ["quantization_utils.h"], + deps = [ + "//tensorflow/core:framework", + "@gemmlowp", + ], +) + cc_library( name = "remote_fused_graph_execute_utils", srcs = [ @@ -6067,7 +6090,6 @@ cc_library( srcs = [ "cwise_ops_common.cc", "meta_support.cc", - "quantization_utils.cc", ], hdrs = [ "cwise_ops.h", @@ -6076,10 +6098,10 @@ cc_library( "cwise_ops_gpu_gradients.cu.h", "cwise_ops_gradients.h", "meta_support.h", - "quantization_utils.h", ], deps = [ ":bounds_check", + ":quantization_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", "//third_party/eigen3", diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h index 25c5f9cf424fdb286922548ea7ab0a35157a3502..661ed239d316d378dac97d994ae07eb147b9bca1 100644 --- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h +++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h @@ -50,43 +50,26 @@ class ASBSQueue; // track of a number of queues (one per model or model version) which are // continuously enqueuing requests. The scheduler groups the requests into // batches which it periodically sends off for processing (see -// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler -// prioritizes batches by age (i.e. the batch's oldest request) irrespective of -// queue or batch size. +// shared_batch_scheduler.h for more details). AdaptiveSharedBatchScheduler +// (ASBS) prioritizes batches by age (i.e. the batch's oldest request) +// irrespective of queue or batch size. // -// The scheduling decision currently exists in two flavors, controlled by the -// option use_in_flight_batches_implementation. It is expected that setting this -// option to true will give universally better results; after a period of -// testing to confirm, the old implementation will be removed. -// -// If use_in_flight_batches_implementation is set to true, the scheduler -// limits the number of batches which can be processed concurrently. If a new -// batch is created, and the number of in flight batches is below the limit, -// the next (i.e. oldest) batch is immediately scheduled. Similarly, when a -// batch finishes processing, the limit is rechecked, and another batch may be -// scheduled. To avoid the need to carefully tune the limit for workload, -// model type, platform, etc, it is dynamically adjusted in order to provide the -// lowest latency. -// -// If use_in_flight_batches_implementation is set to false, the scheduler will -// process the oldest batch at an adjustable rate, regardless of batch size. -// The user can provide feedback to help set this rate to achieve some goal -// (i.e. minimize overall latency, limit cpu usage, etc). The rate (or rather, -// the corresponding period) is adjusted each time a batch is processed, using -// an exponentially weighted moving average to smooth noisy feedback: -// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N -// period *= (1 + K * emwa_feedback) +// ASBS tries to keep the system busy by maintaining an adjustable number of +// concurrently processed batches. If a new batch is created, and the number of +// in flight batches is below the target, the next (i.e. oldest) batch is +// immediately scheduled. Similarly, when a batch finishes processing, the +// target is rechecked, and another batch may be scheduled. To avoid the need +// to carefully tune the target for workload, model type, platform, etc, it is +// dynamically adjusted in order to provide the lowest average latency. // // Some potential use cases: // Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing // involves serial processing by a device, from a latency perspective it is // desirable to keep the device evenly loaded, avoiding the need to wait for // the device to process prior batches. -// feedback = num_pending_on_device() - desired_pending. // CPU utilization - If the batch processing is cpu dominated, you can reap // latency gains when underutilized by increasing the processing rate, but // back the rate off when the load increases to avoid overload. -// feedback = cpu_rate() - desired_cpu_rate. template class AdaptiveSharedBatchScheduler @@ -101,13 +84,17 @@ class AdaptiveSharedBatchScheduler struct Options { // The name to use for the pool of batch threads. string thread_pool_name = {"batch_threads"}; - // Number of batch processing threads; equivalently the maximum number of - // concurrently running batches. + // Number of batch processing threads - the maximum value of + // in_flight_batches_limit_. It is recommended that this value be set by + // running the system under load, observing the learned value for + // in_flight_batches_limit_, and setting this maximum to ~ 2x the value. + // Under low load, in_flight_batches_limit_ has no substantial effect on + // latency and therefore undergoes a random walk. Unreasonably large values + // for num_batch_threads allows for large in_flight_batches_limit_, which + // will harm latency for some time once load increases again. int64 num_batch_threads = port::NumSchedulableCPUs(); // The environment to use (typically only overridden by test code). Env* env = Env::Default(); - // Which implementation to use (described in class comments above). - bool use_in_flight_batches_implementation = false; // Initial limit for number of batches being concurrently processed. // Non-integer values correspond to probabilistic limits - i.e. a value of // 3.2 results in an actual cap of 3 80% of the time, and 4 20% of the time. @@ -116,28 +103,6 @@ class AdaptiveSharedBatchScheduler // numbers will give less noisy latency measurements, but will be less // responsive to changes in workload. int64 batches_to_average_over = 1000; - - // TODO(kte): remove the rate based implementation and corresponding options - // below once testing confirms the superiority of the in flight batches - // implementation. - // Initial batch scheduling period in microseconds. Will be altered for - // non-zero rate_feedback. - double initial_scheduling_period_micros = 500; - // Minimum batch scheduling period in microseconds. Recommend setting this - // value greater than 0, otherwise it may take a while to recover from a - // sustained time of negative scheduling_period_feedback (which may occur - // under low load). - double min_scheduling_period_micros = 100; - // Maximum batch scheduling period in microseconds. - double max_scheduling_period_micros = 10000; - // Feedback function used to modify the scheduling period each time a batch - // is scheduled. Should return values roughly O(1), with positive values - // resulting in an increased period. - std::function scheduling_period_feedback{[] { return 0.; }}; - // To handle potentially noisy scheduling_period_feedback, the period is - // adjusted using an exponentially weighted moving average over the previous - // feedback_smoothing_batches batches. Must be greater than 0. - int64 feedback_smoothing_batches = 10; }; // Ownership is shared between the caller of Create() and any queues created @@ -171,17 +136,11 @@ class AdaptiveSharedBatchScheduler explicit AdaptiveSharedBatchScheduler(const Options& options); - // Batch scheduling function which runs every scheduling_period_ microseconds. - // Only used when options_.use_in_flight_batches_implementation == false. - void ProcessOneBatch(); - // Tracks processing latency and adjusts in_flight_batches_limit to minimize. - // Only used when options_.use_in_flight_batches_implementation == true. void CallbackWrapper(const internal::ASBSBatch* batch, BatchProcessor callback); // Schedules batch if in_flight_batches_limit_ is not met. - // Only used when options_.use_in_flight_batches_implementation == true. void MaybeScheduleNextBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_); // Notifies scheduler of non-empty batch which is eligible for processing. @@ -212,41 +171,22 @@ class AdaptiveSharedBatchScheduler mutex mu_; - // Responsible for running ProcessOneBatch. PeriodicFunction was used in order - // to check for deletion so that the thread can be shut down. - // Only used when options_.use_in_flight_batches_implementation == false. - std::unique_ptr scheduling_thread_; - // Responsible for running the batch processing callbacks. std::unique_ptr batch_thread_pool_; - // Time interval in microseconds between successive ProcessOneBatch calls. - // Only used when options_.use_in_flight_batches_implementation == false. - double scheduling_period_; - - // Exponentially weighted moving average of - // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch - // call. - // Only used when options_.use_in_flight_batches_implementation == false. - double ewma_feedback_ = 0; - // Limit on number of batches which can be concurrently processed. // Non-integer values correspond to probabilistic limits - i.e. a value of 3.2 // results in an actual cap of 3 80% of the time, and 4 20% of the time. - // Only used when options_.use_in_flight_batches_implementation == true. double in_flight_batches_limit_ GUARDED_BY(mu_); // Number of batches currently being processed. - // Only used when options_.use_in_flight_batches_implementation == true. int64 in_flight_batches_ GUARDED_BY(mu_) = 0; // RNG engine and distribution. - // Only used when options_.use_in_flight_batches_implementation == true. std::default_random_engine rand_engine_; std::uniform_real_distribution rand_double_; // Fields controlling the dynamic adjustment of in_flight_batches_limit_. - // Only used when options_.use_in_flight_batches_implementation == true. // Number of batches since the last in_flight_batches_limit_ adjustment. int64 batch_count_ GUARDED_BY(mu_) = 0; // Sum of processing latency for batches counted by batch_count_. @@ -348,32 +288,6 @@ Status AdaptiveSharedBatchScheduler::Create( return errors::InvalidArgument("num_batch_threads must be positive; was ", options.num_batch_threads); } - if (options.min_scheduling_period_micros < 0) { - return errors::InvalidArgument( - "min_scheduling_period_micros must be >= 0; was ", - options.min_scheduling_period_micros); - } - if (options.min_scheduling_period_micros > - options.initial_scheduling_period_micros) { - return errors::InvalidArgument( - "initial_scheduling_period_micros (", - options.initial_scheduling_period_micros, - ") must be >= min_scheduling_period_micros (", - options.min_scheduling_period_micros, ")"); - } - if (options.initial_scheduling_period_micros > - options.max_scheduling_period_micros) { - return errors::InvalidArgument( - "initial_scheduling_period_micros (", - options.initial_scheduling_period_micros, - ") must be <= max_scheduling_period_micros (", - options.max_scheduling_period_micros, ")"); - } - if (options.feedback_smoothing_batches < 1) { - return errors::InvalidArgument( - "feedback_smoothing_batches must be positive; was ", - options.feedback_smoothing_batches); - } if (options.initial_in_flight_batches_limit > options.num_batch_threads) { return errors::InvalidArgument( "initial_in_flight_batches_limit (", @@ -401,20 +315,12 @@ template AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( const Options& options) : options_(options), - scheduling_period_(options.initial_scheduling_period_micros), in_flight_batches_limit_(options.initial_in_flight_batches_limit), rand_double_(0.0, 1.0) { std::random_device device; rand_engine_.seed(device()); - PeriodicFunction::Options opts; - opts.thread_name_prefix = "scheduling_thread"; - opts.env = GetEnv(); batch_thread_pool_.reset(new thread::ThreadPool( GetEnv(), options.thread_pool_name, options.num_batch_threads)); - if (!options.use_in_flight_batches_implementation) { - scheduling_thread_.reset( - new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts)); - } } template @@ -443,9 +349,7 @@ void AdaptiveSharedBatchScheduler::AddBatch( const internal::ASBSBatch* batch) { mutex_lock l(mu_); batches_.push(batch); - if (options_.use_in_flight_batches_implementation) { - MaybeScheduleNextBatch(); - } + MaybeScheduleNextBatch(); } template @@ -523,44 +427,6 @@ void AdaptiveSharedBatchScheduler::CallbackWrapper( MaybeScheduleNextBatch(); } -template -void AdaptiveSharedBatchScheduler::ProcessOneBatch() { - static const double kFeedbackMultiplier = .001; - const internal::ASBSBatch* batch = nullptr; - BatchProcessor callback; - const int64 start_time_micros = GetEnv()->NowMicros(); - { - mutex_lock l(mu_); - if (!batches_.empty()) { - batch = batches_.top(); - batches_.pop(); - callback = queues_and_callbacks_[batch->queue()]; - } - } - if (batch != nullptr) { - double feedback = options_.scheduling_period_feedback(); - const int64 N = options_.feedback_smoothing_batches; - ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N; - scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_); - if (scheduling_period_ < options_.min_scheduling_period_micros) { - scheduling_period_ = options_.min_scheduling_period_micros; - } else if (scheduling_period_ > options_.max_scheduling_period_micros) { - scheduling_period_ = options_.max_scheduling_period_micros; - } - // Queue may destroy itself after ReleaseBatch is called. - batch->queue()->ReleaseBatch(batch); - batch_thread_pool_->Schedule([callback, batch] { - callback(std::unique_ptr>( - const_cast*>(batch))); - }); - } - const int64 sleep_time = - scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros); - if (sleep_time > 0) { - GetEnv()->SleepForMicroseconds(sleep_time); - } -} - template bool AdaptiveSharedBatchScheduler::BatchCompare::operator()( const internal::ASBSBatch* a, diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc index 8ae8ca02eca20b5d1184e6e588f013d59d10464a..109234287e8f274bc1f6903881ecc7cfecc0edbc 100644 --- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc +++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler_test.cc @@ -64,59 +64,6 @@ std::unique_ptr CreateFakeClockAdvancerThread( })); } -TEST(AdaptiveSharedBatchSchedulerTest, Basic) { - for (const bool delete_scheduler_early : {false, true}) { - for (const bool delete_queue_1_early : {false, true}) { - int queue_0_tasks = 0; - auto queue_0_callback = - [&queue_0_tasks](std::unique_ptr> batch) { - ASSERT_TRUE(batch->IsClosed()); - EXPECT_GT(batch->num_tasks(), 0); - for (int i = 0; i < batch->num_tasks(); i++) { - queue_0_tasks += batch->task(i).size(); - } - }; - int queue_1_tasks = 0; - auto queue_1_callback = - [&queue_1_tasks](std::unique_ptr> batch) { - ASSERT_TRUE(batch->IsClosed()); - EXPECT_GT(batch->num_tasks(), 0); - for (int i = 0; i < batch->num_tasks(); i++) { - queue_1_tasks += batch->task(i).size(); - } - }; - { - std::shared_ptr> scheduler; - TF_ASSERT_OK( - AdaptiveSharedBatchScheduler::Create({}, &scheduler)); - - // Create two queues. - std::unique_ptr> queue_0; - TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0)); - std::unique_ptr> queue_1; - TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1)); - - if (delete_scheduler_early) { - // Delete our copy of the scheduler. The queues should keep it alive - // under the covers. - scheduler = nullptr; - } - // Submit tasks to the two queues, and (optionally) remove the queues. - TF_ASSERT_OK(ScheduleTask(1, queue_0.get())); - TF_ASSERT_OK(ScheduleTask(2, queue_1.get())); - TF_ASSERT_OK(ScheduleTask(3, queue_0.get())); - TF_ASSERT_OK(ScheduleTask(4, queue_1.get())); - if (delete_queue_1_early) { - queue_1 = nullptr; - } - TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); - } - EXPECT_EQ(queue_0_tasks, 9); - EXPECT_EQ(queue_1_tasks, 6); - } - } -} - TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) { using Scheduler = AdaptiveSharedBatchScheduler; std::shared_ptr scheduler; @@ -124,24 +71,6 @@ TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) { options.num_batch_threads = 0; EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); options = Scheduler::Options(); - options.min_scheduling_period_micros = 50; - options.max_scheduling_period_micros = 100; - options.initial_scheduling_period_micros = 1; - EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); - options = Scheduler::Options(); - options.min_scheduling_period_micros = 50; - options.max_scheduling_period_micros = 100; - options.initial_scheduling_period_micros = 1000; - EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); - options = Scheduler::Options(); - options.min_scheduling_period_micros = 100; - options.max_scheduling_period_micros = 50; - options.initial_scheduling_period_micros = 75; - EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); - options = Scheduler::Options(); - options.feedback_smoothing_batches = 0; - EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); - options = Scheduler::Options(); options.initial_in_flight_batches_limit = 0.5; EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); options = Scheduler::Options(); @@ -153,301 +82,8 @@ TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) { EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); } -TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { - test_util::FakeClockEnv env(Env::Default()); - Notification start_teardown, stop_teardown; - std::unique_ptr teardown_thread = - CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); - { - AdaptiveSharedBatchScheduler::Options options; - options.initial_scheduling_period_micros = 1000; - options.env = &env; - std::shared_ptr> scheduler; - TF_ASSERT_OK( - AdaptiveSharedBatchScheduler::Create(options, &scheduler)); - std::unique_ptr> queue_0; - std::unique_ptr> queue_1; - int queue_0_tasks = 0; - int queue_1_tasks = 0; - auto queue_0_callback = [&queue_0_tasks, - &env](std::unique_ptr> batch) { - ASSERT_TRUE(batch->IsClosed()); - EXPECT_GT(batch->num_tasks(), 0); - for (int i = 0; i < batch->num_tasks(); i++) { - queue_0_tasks += batch->task(i).size(); - } - env.SleepForMicroseconds(1); - }; - auto queue_1_callback = [&queue_1_tasks, - &env](std::unique_ptr> batch) { - ASSERT_TRUE(batch->IsClosed()); - EXPECT_GT(batch->num_tasks(), 0); - for (int i = 0; i < batch->num_tasks(); i++) { - queue_1_tasks += batch->task(i).size(); - } - env.SleepForMicroseconds(1); - }; - AdaptiveSharedBatchScheduler::QueueOptions queue_options; - queue_options.max_batch_size = 10; - queue_options.max_enqueued_batches = 0; - // Queue must have max_enqueued_batchs > 1. - EXPECT_FALSE( - scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok()); - queue_options.max_enqueued_batches = 2; - TF_ASSERT_OK( - scheduler->AddQueue(queue_options, queue_0_callback, &queue_0)); - EXPECT_EQ(10, queue_0->max_task_size()); - queue_options.max_batch_size = 0; - // Queue must have max_batch_size > 0. - EXPECT_FALSE( - scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok()); - queue_options.max_batch_size = 2; - queue_options.max_enqueued_batches = 1; - TF_ASSERT_OK( - scheduler->AddQueue(queue_options, queue_1_callback, &queue_1)); - - // Wait for scheduling_thread to sleep. - env.BlockUntilThreadsAsleep(1); - // Task larger than max_batch_size shouldn't schedule. - EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok()); - TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); - TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); - env.AdvanceByMicroseconds(1); - - // Task larger than max_batch_size shouldn't schedule. - EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok()); - TF_ASSERT_OK(ScheduleTask(1, queue_1.get())); - TF_ASSERT_OK(ScheduleTask(1, queue_1.get())); - env.AdvanceByMicroseconds(1); - // Exceeds max_enqueued_batches, shouldn't schedule. - EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok()); - - TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); - // Exceeds max_enqueued_batches, shouldn't schedule. - EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok()); - TF_ASSERT_OK(ScheduleTask(4, queue_0.get())); - - // Batches should be processed in order from oldest to newest. - env.AdvanceByMicroseconds(1000); - env.BlockUntilThreadsAsleep(2); - EXPECT_EQ(queue_0_tasks, 10); - EXPECT_EQ(queue_1_tasks, 0); - - env.AdvanceByMicroseconds(1000); - env.BlockUntilThreadsAsleep(2); - EXPECT_EQ(queue_0_tasks, 10); - EXPECT_EQ(queue_1_tasks, 2); - - env.AdvanceByMicroseconds(1000); - env.BlockUntilThreadsAsleep(2); - EXPECT_EQ(queue_0_tasks, 19); - EXPECT_EQ(queue_1_tasks, 2); - start_teardown.Notify(); - } - stop_teardown.Notify(); -} - -TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) { - test_util::FakeClockEnv env(Env::Default()); - Notification start_teardown, stop_teardown; - std::unique_ptr teardown_thread = - CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); - { - double feedback = 0; - AdaptiveSharedBatchScheduler::Options options; - options.initial_scheduling_period_micros = 1000; - options.min_scheduling_period_micros = 200; - options.max_scheduling_period_micros = 2000; - options.env = &env; - options.scheduling_period_feedback = [&feedback] { return feedback; }; - options.feedback_smoothing_batches = 1; - std::shared_ptr> scheduler; - TF_ASSERT_OK( - AdaptiveSharedBatchScheduler::Create(options, &scheduler)); - std::unique_ptr> queue; - int scheduled_items = 0; - auto queue_callback = [&scheduled_items, - &env](std::unique_ptr> batch) { - ASSERT_TRUE(batch->IsClosed()); - EXPECT_GT(batch->num_tasks(), 0); - scheduled_items = 0; - for (int i = 0; i < batch->num_tasks(); i++) { - scheduled_items += batch->task(i).size(); - } - env.SleepForMicroseconds(1); - }; - - TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); - - // Wait for scheduling_thread to sleep. - env.BlockUntilThreadsAsleep(1); - // Enqueue 6 batches. - for (int i = 0; i < 6; i++) { - TF_ASSERT_OK(ScheduleTask(900 + i, queue.get())); - env.AdvanceByMicroseconds(1); - } - feedback = -500; - env.AdvanceByMicroseconds(994); - env.BlockUntilThreadsAsleep(2); // scheduling period = 500 usec. - EXPECT_EQ(scheduled_items, 900); - env.AdvanceByMicroseconds(500); - env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec. - EXPECT_EQ(scheduled_items, 901); - feedback = 0; - env.AdvanceByMicroseconds(250); - env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec. - EXPECT_EQ(scheduled_items, 902); - feedback = 10000; // large feedback should hit max_scheduling_period. - env.AdvanceByMicroseconds(250); - env.BlockUntilThreadsAsleep(2); // scheduling period = 2000 usec. - EXPECT_EQ(scheduled_items, 903); - feedback = -10000; // large feedback should hit min_scheduling_period. - env.AdvanceByMicroseconds(1999); - // No callback scheduled, only scheduling thread sleeping. - env.BlockUntilThreadsAsleep(1); - EXPECT_EQ(scheduled_items, 903); - env.AdvanceByMicroseconds(1); - env.BlockUntilThreadsAsleep(2); // scheduling period = 200 usec. - EXPECT_EQ(scheduled_items, 904); - env.AdvanceByMicroseconds(200); - env.BlockUntilThreadsAsleep(2); - EXPECT_EQ(scheduled_items, 905); - start_teardown.Notify(); - } - stop_teardown.Notify(); -} - -TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) { - test_util::FakeClockEnv env(Env::Default()); - Notification start_teardown, stop_teardown; - std::unique_ptr teardown_thread = - CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); - { - double feedback = 0; - AdaptiveSharedBatchScheduler::Options options; - options.initial_scheduling_period_micros = 1000; - options.env = &env; - options.scheduling_period_feedback = [&feedback] { return feedback; }; - options.feedback_smoothing_batches = 3; - std::shared_ptr> scheduler; - TF_ASSERT_OK( - AdaptiveSharedBatchScheduler::Create(options, &scheduler)); - std::unique_ptr> queue; - int scheduled_items = 0; - auto queue_callback = [&scheduled_items, - &env](std::unique_ptr> batch) { - ASSERT_TRUE(batch->IsClosed()); - EXPECT_GT(batch->num_tasks(), 0); - scheduled_items = 0; - for (int i = 0; i < batch->num_tasks(); i++) { - scheduled_items += batch->task(i).size(); - } - env.SleepForMicroseconds(1); - }; - - TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); - - // Wait for scheduling_thread to sleep. - env.BlockUntilThreadsAsleep(1); - // Enqueue 4 batches. - for (int i = 0; i < 4; i++) { - TF_ASSERT_OK(ScheduleTask(900 + i, queue.get())); - env.AdvanceByMicroseconds(1); - } - feedback = -300; - env.AdvanceByMicroseconds(996); - env.BlockUntilThreadsAsleep(2); - // ewma_feedback = 100, scheduling_period = 900. - EXPECT_EQ(scheduled_items, 900); - env.AdvanceByMicroseconds(899); - // No callback scheduled, only scheduling thread sleeping. - env.BlockUntilThreadsAsleep(1); - EXPECT_EQ(scheduled_items, 900); - env.AdvanceByMicroseconds(1); - env.BlockUntilThreadsAsleep(2); - // ewma_feedback = 167, scheduling_period = 750. - EXPECT_EQ(scheduled_items, 901); - env.AdvanceByMicroseconds(749); - // No callback scheduled, only scheduling thread sleeping. - env.BlockUntilThreadsAsleep(1); - EXPECT_EQ(scheduled_items, 901); - feedback = 1000 / 3.; - env.AdvanceByMicroseconds(1); - env.BlockUntilThreadsAsleep(2); - // emwa_feedback = 0, scheduling_period = 750. - EXPECT_EQ(scheduled_items, 902); - env.AdvanceByMicroseconds(749); - // No callback scheduled, only scheduling thread sleeping. - env.BlockUntilThreadsAsleep(1); - EXPECT_EQ(scheduled_items, 902); - env.AdvanceByMicroseconds(1); - env.BlockUntilThreadsAsleep(2); - EXPECT_EQ(scheduled_items, 903); - start_teardown.Notify(); - } - stop_teardown.Notify(); -} - -TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) { - test_util::FakeClockEnv env(Env::Default()); - Notification start_teardown, stop_teardown; - std::unique_ptr teardown_thread = - CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); - { - AdaptiveSharedBatchScheduler::Options options; - options.initial_scheduling_period_micros = 1000; - options.env = &env; - std::shared_ptr> scheduler; - TF_ASSERT_OK( - AdaptiveSharedBatchScheduler::Create(options, &scheduler)); - std::unique_ptr> queue; - int scheduled_items = 0; - auto queue_callback = [&scheduled_items, - &env](std::unique_ptr> batch) { - ASSERT_TRUE(batch->IsClosed()); - EXPECT_GT(batch->num_tasks(), 0); - scheduled_items = 0; - for (int i = 0; i < batch->num_tasks(); i++) { - scheduled_items += batch->task(i).size(); - } - env.SleepForMicroseconds(1); - }; - AdaptiveSharedBatchScheduler::QueueOptions queue_options; - queue_options.max_batch_size = 10; - queue_options.max_enqueued_batches = 10; - TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue)); - - // Wait for scheduling_thread to sleep. - env.BlockUntilThreadsAsleep(1); - // Enqueue 3 tasks. - EXPECT_EQ(queue->NumEnqueuedTasks(), 0); - EXPECT_EQ(queue->SchedulingCapacity(), 100); - TF_ASSERT_OK(ScheduleTask(5, queue.get())); - EXPECT_EQ(queue->NumEnqueuedTasks(), 1); - EXPECT_EQ(queue->SchedulingCapacity(), 95); - env.AdvanceByMicroseconds(1); - TF_ASSERT_OK(ScheduleTask(6, queue.get())); - EXPECT_EQ(queue->NumEnqueuedTasks(), 2); - EXPECT_EQ(queue->SchedulingCapacity(), 84); - env.AdvanceByMicroseconds(1); - TF_ASSERT_OK(ScheduleTask(1, queue.get())); - EXPECT_EQ(queue->NumEnqueuedTasks(), 3); - EXPECT_EQ(queue->SchedulingCapacity(), 83); - - env.AdvanceByMicroseconds(998); - env.BlockUntilThreadsAsleep(2); - EXPECT_EQ(scheduled_items, 5); - env.AdvanceByMicroseconds(1000); - env.BlockUntilThreadsAsleep(2); - EXPECT_EQ(scheduled_items, 7); - start_teardown.Notify(); - } - stop_teardown.Notify(); -} - -TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesImplementation) { +TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesLimit) { AdaptiveSharedBatchScheduler::Options options; - options.use_in_flight_batches_implementation = true; options.initial_in_flight_batches_limit = 2; options.batches_to_average_over = 1000; mutex mu; @@ -476,7 +112,7 @@ TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesImplementation) { std::unique_ptr> queue; TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); - // Enqueue 3 batches. + // Enqueue 3 tasks, should result in 3 batches. for (int i = 0; i < 3; i++) { TF_ASSERT_OK(ScheduleTask(100, queue.get())); } @@ -490,7 +126,6 @@ TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesLimitTuning) { { AdaptiveSharedBatchScheduler::Options options; options.env = &env; - options.use_in_flight_batches_implementation = true; options.initial_in_flight_batches_limit = 2; options.batches_to_average_over = 1; auto queue_callback = [&env](std::unique_ptr> batch) { @@ -544,6 +179,125 @@ TEST(AdaptiveSharedBatchSchedulerTest, InFlightBatchesLimitTuning) { } stop_teardown.Notify(); } + +TEST(AdaptiveSharedBatchSchedulerTest, DeleteQueue) { + AdaptiveSharedBatchScheduler::Options options; + options.initial_in_flight_batches_limit = 1; + options.batches_to_average_over = 1000; + mutex mu; + int processed_batches = 0; + Notification finish_processing; + auto queue_callback = [&mu, &processed_batches, &finish_processing]( + std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + finish_processing.WaitForNotification(); + mu.lock(); + processed_batches++; + mu.unlock(); + }; + + std::unique_ptr queue_deleter; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Enqueue 2 tasks, should result in 2 batches. + for (int i = 0; i < 2; i++) { + TF_ASSERT_OK(ScheduleTask(100, queue.get())); + } + // Delete queue, should be kept alive until empty. + queue_deleter.reset(Env::Default()->StartThread( + {}, "QueueDeleterThread", [&queue, &mu, &processed_batches] { + queue.reset(); + mutex_lock l(mu); + EXPECT_EQ(processed_batches, 2); + })); + // Give queue_deleter thread time to delete queue. + Env::Default()->SleepForMicroseconds(1000); + finish_processing.Notify(); +} + +TEST(AdaptiveSharedBatchSchedulerTest, DeleteScheduler) { + AdaptiveSharedBatchScheduler::Options options; + options.initial_in_flight_batches_limit = 1; + options.batches_to_average_over = 1000; + mutex mu; + int processed_batches = 0; + Notification finish_processing; + auto queue_callback = [&mu, &processed_batches, &finish_processing]( + std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + finish_processing.WaitForNotification(); + mu.lock(); + processed_batches++; + mu.unlock(); + }; + + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Enqueue 2 tasks, should result in 2 batches. + for (int i = 0; i < 2; i++) { + TF_ASSERT_OK(ScheduleTask(100, queue.get())); + } + // Delete scheduler, should be kept alive until queues are empty. + scheduler.reset(); + finish_processing.Notify(); + while (true) { + mutex_lock l(mu); + if (processed_batches == 2) break; + } +} + +TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) { + AdaptiveSharedBatchScheduler::Options options; + options.initial_in_flight_batches_limit = 1; + options.batches_to_average_over = 1000; + mutex mu; + int processed_batches = 0; + Notification finish_processing; + auto queue_callback = [&mu, &processed_batches, &finish_processing]( + std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + mu.lock(); + int batch_num = ++processed_batches; + mu.unlock(); + if (batch_num == 1) { + finish_processing.WaitForNotification(); + } + }; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Enqueue 2 tasks, should result in 2 batches. + for (int i = 0; i < 2; i++) { + TF_ASSERT_OK(ScheduleTask(100, queue.get())); + } + // First batch was immediately processed, no longer counts as enqueued. + EXPECT_EQ(queue->NumEnqueuedTasks(), 1); + EXPECT_EQ(queue->SchedulingCapacity(), 9 * 1000 + 900); + // Enqueue 2 more tasks, should fall in same batch. + TF_ASSERT_OK(ScheduleTask(100, queue.get())); + TF_ASSERT_OK(ScheduleTask(200, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 3); + EXPECT_EQ(queue->SchedulingCapacity(), 9 * 1000 + 600); + // Enqueue 1 more task, should create new batch. + TF_ASSERT_OK(ScheduleTask(700, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 4); + EXPECT_EQ(queue->SchedulingCapacity(), 8 * 1000 + 300); + finish_processing.Notify(); +} } // namespace anonymous } // namespace serving } // namespace tensorflow diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc index 534527c6bdc9ab971cd4c6001dcef8ee59a13a8d..6040b2b3999770bdd9e39e5209b6b0a1918e1d8e 100644 --- a/tensorflow/core/kernels/check_numerics_op.cc +++ b/tensorflow/core/kernels/check_numerics_op.cc @@ -47,6 +47,8 @@ template class CheckNumericsOp; // Partial specialization for CPU +// TODO(jeff,rmlarsen): We should make this variant be an AsyncOpKernel, as +// was done for the GPU case below. template class CheckNumericsOp : public OpKernel { public: @@ -67,28 +69,32 @@ class CheckNumericsOp : public OpKernel { int fp_props = std::accumulate(data, data + size, 0, [](const int& x, const T& y) { int result = x; - if (Eigen::numext::isinf(y)) { + if (TF_PREDICT_TRUE(Eigen::numext::isfinite(y))) { + // Do nothing: common case + } else if (Eigen::numext::isinf(y)) { result |= kInfBit; } else if (Eigen::numext::isnan(y)) { result |= kNaNBit; } return result; }); - string status; - if ((fp_props & kInfBit) && (fp_props & kNaNBit)) { - status = "Inf and NaN"; - } else { - if (fp_props & kInfBit) { - status = "Inf"; + if (fp_props != 0) { + string status; + if ((fp_props & kInfBit) && (fp_props & kNaNBit)) { + status = "Inf and NaN"; + } else { + if (fp_props & kInfBit) { + status = "Inf"; + } + if (fp_props & kNaNBit) { + status = "NaN"; + } } - if (fp_props & kNaNBit) { - status = "NaN"; + if (!status.empty()) { + context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ", + status, " values")); } } - if (!status.empty()) { - context->SetStatus(errors::InvalidArgument(message_, " : Tensor had ", - status, " values")); - } } private: diff --git a/tensorflow/core/kernels/colorspace_op.cc b/tensorflow/core/kernels/colorspace_op.cc index 9cc2e67bbe1f6919d581def55eb4315f7b908ca3..f4402a245d6c3848430126b3250731008c954df0 100644 --- a/tensorflow/core/kernels/colorspace_op.cc +++ b/tensorflow/core/kernels/colorspace_op.cc @@ -71,7 +71,7 @@ class RGBToHSVOp : public OpKernel { TensorShape({input_data.dimension(0)}), &trange)); - typename TTypes::Tensor range = trange.tensor(); + typename TTypes::Tensor range(trange.tensor()); functor::RGBToHSV()(context->eigen_device(), input_data, range, output_data); diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc index 4ab6fdbca1a3415937213d46fac3058097130f55..312c1a41d36245ae3ca5a09d2e76a430bc464953 100644 --- a/tensorflow/core/kernels/constant_op.cc +++ b/tensorflow/core/kernels/constant_op.cc @@ -102,9 +102,15 @@ REGISTER_KERNEL(GPU, float); REGISTER_KERNEL(GPU, double); REGISTER_KERNEL(GPU, uint8); REGISTER_KERNEL(GPU, int8); +REGISTER_KERNEL(GPU, qint8); REGISTER_KERNEL(GPU, uint16); REGISTER_KERNEL(GPU, int16); +REGISTER_KERNEL(GPU, qint16); +REGISTER_KERNEL(GPU, quint16); +REGISTER_KERNEL(GPU, uint32); +REGISTER_KERNEL(GPU, qint32); REGISTER_KERNEL(GPU, int64); +REGISTER_KERNEL(GPU, uint64); REGISTER_KERNEL(GPU, complex64); REGISTER_KERNEL(GPU, complex128); REGISTER_KERNEL(GPU, bool); @@ -121,9 +127,15 @@ REGISTER_SYCL_KERNEL(SYCL, float); REGISTER_SYCL_KERNEL(SYCL, double); REGISTER_SYCL_KERNEL(SYCL, uint8); REGISTER_SYCL_KERNEL(SYCL, int8); +REGISTER_SYCL_KERNEL(SYCL, qint8); REGISTER_SYCL_KERNEL(SYCL, uint16); REGISTER_SYCL_KERNEL(SYCL, int16); +REGISTER_SYCL_KERNEL(SYCL, qint16); +REGISTER_SYCL_KERNEL(SYCL, quint16); +REGISTER_SYCL_KERNEL(SYCL, uint32); +REGISTER_SYCL_KERNEL(SYCL, qint32); REGISTER_SYCL_KERNEL(SYCL, int64); +REGISTER_SYCL_KERNEL(SYCL, uint64); REGISTER_SYCL_KERNEL(SYCL, bool); #undef REGISTER_SYCL_KERNEL #endif diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h index 2142207b0d89a4b2f02c7f7b5d320c3b4b48462c..6949e5b5fd85f399473095f26314e9d58fa65464 100644 --- a/tensorflow/core/kernels/conv_2d.h +++ b/tensorflow/core/kernels/conv_2d.h @@ -54,10 +54,12 @@ struct InflatePadAndShuffle { template void SpatialConvolutionFunc(const Device& d, Output output, Input input, Filter filter, int row_stride, int col_stride, + int row_dilation, int col_dilation, const Eigen::PaddingType& padding) { // Need to swap row/col when calling Eigen. output.device(d) = - Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding); + Eigen::SpatialConvolution(input, filter, col_stride, row_stride, padding, + col_dilation, row_dilation); } template @@ -65,9 +67,10 @@ struct SpatialConvolution { void operator()(const Device& d, typename TTypes::Tensor output, typename TTypes::ConstTensor input, typename TTypes::ConstTensor filter, int row_stride, - int col_stride, const Eigen::PaddingType& padding) { + int col_stride, int row_dilation, int col_dilation, + const Eigen::PaddingType& padding) { SpatialConvolutionFunc(d, output, input, filter, row_stride, col_stride, - padding); + row_dilation, col_dilation, padding); } }; @@ -77,11 +80,12 @@ struct SpatialConvolution { typename TTypes::Tensor output, typename TTypes::ConstTensor input, typename TTypes::ConstTensor filter, - int row_stride, int col_stride, - const Eigen::PaddingType& padding) { + int row_stride, int col_stride, int row_dilation, + int col_dilation, const Eigen::PaddingType& padding) { output.device(d) = Eigen::SpatialConvolution(input.cast(), filter.cast(), - col_stride, row_stride, padding) + col_stride, row_stride, padding, col_dilation, + row_dilation) .cast(); } }; @@ -91,11 +95,13 @@ struct SpatialConvolutionBackwardInput { void operator()(const Device& d, typename TTypes::Tensor input_backward, typename TTypes::ConstTensor kernel, typename TTypes::ConstTensor output_backward, - int row_stride, int col_stride) { + int row_stride, int col_stride, int row_dilation, + int col_dilation) { // Need to swap row/col when calling Eigen. input_backward.device(d) = Eigen::SpatialConvolutionBackwardInput( kernel, output_backward, input_backward.dimension(2), - input_backward.dimension(1), col_stride, row_stride); + input_backward.dimension(1), col_stride, row_stride, col_dilation, + row_dilation); } }; @@ -105,11 +111,13 @@ struct SpatialConvolutionBackwardFilter { typename TTypes::Tensor kernel_backward, typename TTypes::ConstTensor input, typename TTypes::ConstTensor output_backward, - int row_stride, int col_stride) { + int row_stride, int col_stride, int row_dilation, + int col_dilation) { // Need to swap row/col when calling Eigen. kernel_backward.device(d) = Eigen::SpatialConvolutionBackwardKernel( input, output_backward, kernel_backward.dimension(1), - kernel_backward.dimension(0), col_stride, row_stride); + kernel_backward.dimension(0), col_stride, row_stride, col_dilation, + row_dilation); } }; diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 512bcc6c01bf3eb4aed92f90eebb060abda8a7fc..e6ae59529107e529a9ccf7c790da0da62c90c199 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/fill_functor.h" -#ifdef TENSORFLOW_USE_LIBXSMM +#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS #include "tensorflow/core/kernels/xsmm_conv2d.h" #endif #include "tensorflow/core/kernels/ops_util.h" @@ -101,11 +101,12 @@ struct LaunchConv2DBackpropFilterOp { const CPUDevice& d = ctx->eigen_device(); functor::SpatialConvolutionBackwardFilter()( d, filter_backprop->tensor(), input.tensor(), - out_backprop.tensor(), row_stride, col_stride); + out_backprop.tensor(), row_stride, col_stride, + /*row_dilation=*/1, /*col_dilation=*/1); } }; -#ifdef TENSORFLOW_USE_LIBXSMM +#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS template struct LaunchXsmmBackwardFilter { bool operator()(OpKernelContext* context, const Device& d, @@ -242,7 +243,8 @@ class Conv2DFastBackpropFilterOp : public OpKernel { return; } -#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD +#if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \ + defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS int64 pad_top, pad_bottom; int64 pad_left, pad_right; OP_REQUIRES_OK( @@ -370,7 +372,8 @@ class Conv2DCustomBackpropFilterOp : public OpKernel { dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size, dims.spatial_dims[1].stride, padding_, &dims.spatial_dims[1].output_size, &pad_left, &pad_right)); -#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD +#if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \ + defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS if (pad_left == pad_right && pad_top == pad_bottom) { if (LaunchXsmmBackwardFilter()( context, context->eigen_device(), input.tensor(), diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index 0356ff4c0f4240ec806d1e337546cfce6771d92f..15c55e4d9903b3bbd53e1b6e1c95571ef7834015 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/conv_2d.h" -#ifdef TENSORFLOW_USE_LIBXSMM +#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS #include "tensorflow/core/kernels/xsmm_conv2d.h" #endif #include "tensorflow/core/kernels/ops_util.h" @@ -106,11 +106,12 @@ struct LaunchConv2DBackpropInputOp { const CPUDevice& d = ctx->eigen_device(); functor::SpatialConvolutionBackwardInput()( d, in_backprop->tensor(), filter.tensor(), - out_backprop.tensor(), row_stride, col_stride); + out_backprop.tensor(), row_stride, col_stride, + /*row_dilation=*/1, /*col_dilation=*/1); } }; -#ifdef TENSORFLOW_USE_LIBXSMM +#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS template struct LaunchXsmmBackwardInputConvolution { bool operator()(OpKernelContext* context, const Device& d, @@ -245,7 +246,8 @@ class Conv2DFastBackpropInputOp : public OpKernel { return; } -#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD +#if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \ + defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS int64 pad_top, pad_bottom; int64 pad_left, pad_right; OP_REQUIRES_OK( @@ -362,7 +364,8 @@ class Conv2DCustomBackpropInputOp : public OpKernel { // TODO(andydavis) Consider moving code shared with // Conv2DCustomBackpropFilterOp into a shared helper function. -#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD +#if defined TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS && \ + defined TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS int64 pad_top, pad_bottom; int64 pad_left, pad_right; OP_REQUIRES_OK( diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index dbddaf3dc640dcf2cad8f6ba7dd00aaa33a30e0c..47f6907c04b4e48607e66b5c9601cd9030fa9001 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -32,7 +32,7 @@ limitations under the License. #include "tensorflow/core/kernels/conv_2d.h" #include "tensorflow/core/kernels/deep_conv2d.h" #include "tensorflow/core/kernels/ops_util.h" -#ifdef TENSORFLOW_USE_LIBXSMM +#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS #include "tensorflow/core/kernels/xsmm_conv2d.h" #endif #include "tensorflow/core/lib/core/errors.h" @@ -60,8 +60,8 @@ template struct LaunchGeneric { void operator()(OpKernelContext* ctx, const Tensor& input, const Tensor& filter, int row_stride, int col_stride, - const Padding& padding, Tensor* output, - TensorFormat data_format) { + int row_dilation, int col_dilation, const Padding& padding, + Tensor* output, TensorFormat data_format) { CHECK(data_format == FORMAT_NHWC) << "Generic conv implementation only " "supports NHWC tensor format for now."; if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 && @@ -86,7 +86,8 @@ struct LaunchGeneric { filter.shaped({filter.dim_size(2), filter.dim_size(3)}), dim_pair); } else if (filter.dim_size(0) == input.dim_size(1) && - filter.dim_size(1) == input.dim_size(2) && padding == VALID) { + filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 && + col_dilation == 1 && padding == VALID) { // If the input data and filter have the same height/width, // the 2D convolution is reduced to matrix multiplication. const int k = // Length of reduction dimension. @@ -103,7 +104,7 @@ struct LaunchGeneric { functor::SpatialConvolution()( ctx->eigen_device(), output->tensor(), input.tensor(), filter.tensor(), row_stride, col_stride, - BrainPadding2EigenPadding(padding)); + row_dilation, col_dilation, BrainPadding2EigenPadding(padding)); } } }; @@ -122,15 +123,9 @@ struct LaunchConv2DOp { "NHWC tensor format for now.")); return; } - // TODO(yangzihao): Add the CPU implementation of dilated conv 2D. - if (row_dilation > 1 || col_dilation > 1) { - ctx->SetStatus( - errors::Unimplemented("Generic conv implementation only supports " - "dilated rate of 1 for now.")); - return; - } LaunchGeneric()(ctx, input, filter, row_stride, col_stride, - padding, output, data_format); + row_dilation, col_dilation, padding, output, + data_format); } }; @@ -190,7 +185,7 @@ class LaunchDeepConvOp { } }; -#ifdef TENSORFLOW_USE_LIBXSMM +#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS template class LaunchXsmmConvOp { public: @@ -406,7 +401,7 @@ class Conv2DOp : public BinaryOp { return; } -#ifdef TENSORFLOW_USE_LIBXSMM +#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS if (LaunchXsmmConvOp::Run( context, input, filter, batch, input_rows, input_cols, in_depth, filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols, @@ -792,7 +787,8 @@ namespace functor { const GPUDevice& d, typename TTypes::Tensor output, \ typename TTypes::ConstTensor input, \ typename TTypes::ConstTensor filter, int row_stride, \ - int col_stride, const Eigen::PaddingType& padding); \ + int col_stride, int row_dilation, int col_dilation, \ + const Eigen::PaddingType& padding); \ extern template struct SpatialConvolution; \ template <> \ void MatMulConvFunctor::operator()( \ diff --git a/tensorflow/core/kernels/critical_section.cc b/tensorflow/core/kernels/critical_section.cc deleted file mode 100644 index 30a9abf4ee78cdb336e4c25c217239daf89bae11..0000000000000000000000000000000000000000 --- a/tensorflow/core/kernels/critical_section.cc +++ /dev/null @@ -1,246 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include -#include - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/kernels/captured_function.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -class CriticalSection : public ResourceBase { - public: - explicit CriticalSection() : is_locked_(false) {} - ~CriticalSection() override { - // Wait for all closures to finish running. - mutex_lock lock(mu_); - while (!closures_.empty()) { - queue_empty_cv_.wait(lock); - } - } - - private: - friend class ExecuteInCriticalSectionOp; - - void Acquire(std::function closure) { - std::function next; - { - mutex_lock ml(mu_); - if (is_locked_) { - closures_.push_back(std::move(closure)); - } else { - // This branch is the common case. Avoid the queue. - is_locked_ = true; - next = std::move(closure); - } - } - if (next) { - next(); - } - } - - void Release() { - std::function next; - { - mutex_lock ml(mu_); - CHECK(is_locked_); - if (!closures_.empty()) { - // if queue is not empty, start the next entry off the queue. - std::swap(next, closures_.front()); - closures_.pop_front(); - } else { - is_locked_ = false; - queue_empty_cv_.notify_all(); - } - } - if (next) { - next(); - } - } - - string DebugString() override { - tf_shared_lock ml(mu_); - return strings::StrCat("CriticalSection(locked: ", is_locked_, - " queue_size: ", closures_.size(), ")"); - } - - private: - mutex mu_; - std::deque> closures_ GUARDED_BY(mu_); - bool is_locked_ GUARDED_BY(mu_); - condition_variable queue_empty_cv_ GUARDED_BY(mu_); -}; - -class ExecuteInCriticalSectionOp : public AsyncOpKernel { - public: - explicit ExecuteInCriticalSectionOp(OpKernelConstruction* c) - : AsyncOpKernel(c) { - OP_REQUIRES_OK(c, c->GetAttr("f", &func_)); - } - - public: - void ComputeAsync(OpKernelContext* c, DoneCallback done) override { - CriticalSection* critical_section = nullptr; - OP_REQUIRES_OK_ASYNC(c, - LookupOrCreateResource( - c, HandleFromInput(c, 0), &critical_section, - [this, c](CriticalSection** ptr) { - *ptr = new CriticalSection; - return Status::OK(); - }), - done); - // No need to Unref critical_section; the Closure below will take - // care of the Unref associated with this execution. - - auto* execution = new Closure{std::move(done), c, critical_section, &func_}; - execution->Start(); - } - - private: - class Closure { - public: - AsyncOpKernel::DoneCallback done_; - OpKernelContext* ctx_; - CriticalSection* cs_; - FunctionLibraryRuntime::Handle handle_; - FunctionLibraryRuntime::Options opts_; - std::vector arguments_t_; - std::vector output_t_; - NameAttrList* func_; - - explicit Closure(AsyncOpKernel::DoneCallback done, OpKernelContext* ctx, - CriticalSection* critical_section, NameAttrList* func) - : done_(std::move(done)), - ctx_(ctx), - cs_(critical_section), - handle_(-1), - func_(func) {} - - ~Closure(); - - void Start() { - // Perform ExecuteFunction isnide a separate thread to avoid - // having lightweight Functions be inlined in this thread. - // That inlining would in turn inline DoneAndDelete inside the - // same thread. Since DoneAndDelete can call the next - // ExecuteFunction in the CriticalSection, this can cause a - // stack overflow. - cs_->Acquire( - [this]() { (*ctx_->runner())([this]() { ExecuteFunction(); }); }); - } - - private: - void ExecuteFunction(); - void DoneAndDelete(const Status& status); - }; - - NameAttrList func_; -}; - -void ExecuteInCriticalSectionOp::Closure::ExecuteFunction() { - // Arguments to a Function are in the order: - // concat(, ) - OpInputList arguments; - Status s = ctx_->input_list("arguments", &arguments); - if (!s.ok()) { - DoneAndDelete(s); - return; - } - - arguments_t_.reserve(arguments.size()); - for (const Tensor& t : arguments) { - arguments_t_.push_back(t); - } - - auto* function_library = ctx_->function_library(); - s = function_library->Instantiate(func_->name(), AttrSlice(&func_->attr()), - &handle_); - if (!s.ok()) { - DoneAndDelete(s); - return; - } - - opts_.step_id = CapturedFunction::generate_step_id(); - auto* step_container = - new ScopedStepContainer(opts_.step_id, [this](const string& name) { - ctx_->resource_manager()->Cleanup(name).IgnoreError(); - }); - opts_.cancellation_manager = ctx_->cancellation_manager(); - opts_.step_container = step_container; - opts_.runner = ctx_->runner(); - - function_library->Run(opts_, handle_, arguments_t_, &output_t_, - [this](const Status& s) { DoneAndDelete(s); }); -} - -void ExecuteInCriticalSectionOp::Closure::DoneAndDelete(const Status& status) { - cs_->Release(); - - if (!status.ok()) { - ctx_->SetStatus(status); - } else { - OpOutputList output; - const Status s = ctx_->output_list("outputs", &output); - if (!s.ok()) { - ctx_->SetStatus(s); - } else if (output_t_.size() != output.size()) { - ctx_->SetStatus(errors::Internal( - "Could not set all outputs. Expected output size is ", output.size(), - " but function set ", output_t_.size(), " output values.")); - } else { - for (int i = 0; i < output_t_.size(); ++i) { - output.set(i, output_t_[i]); - } - } - } - - delete opts_.step_container; - opts_.step_container = nullptr; - done_(); - cs_->Unref(); - delete this; -} - -ExecuteInCriticalSectionOp::Closure::~Closure() { - CHECK(!opts_.step_container) - << "Initialized closure destroyed without calling Done"; -} - -REGISTER_KERNEL_BUILDER(Name("ExecuteInCriticalSection").Device(DEVICE_CPU), - ExecuteInCriticalSectionOp); - -REGISTER_KERNEL_BUILDER(Name("CriticalSectionOp").Device(DEVICE_CPU), - ResourceHandleOp); - -// TODO(ebrevdo): Re-enable once the cross-device function execution works. -#if GOOGLE_CUDA -REGISTER_KERNEL_BUILDER(Name("ExecuteInCriticalSection") - .Device(DEVICE_GPU) - .HostMemory("critical_section"), - ExecuteInCriticalSectionOp); -REGISTER_KERNEL_BUILDER( - Name("CriticalSectionOp").Device(DEVICE_GPU).HostMemory("resource"), - ResourceHandleOp); -#endif // GOOGLE_CUDA - -} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc index 39f497e71612fc08a085e410edae73669fc9993a..696d5840e8ce39c1bf210b54b9f28ae83cf232c7 100644 --- a/tensorflow/core/kernels/cwise_ops_test.cc +++ b/tensorflow/core/kernels/cwise_ops_test.cc @@ -231,14 +231,22 @@ BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, half, DT_HALF); Graph* BcastAdd(int rows, int cols, int dim) { Graph* g = new Graph(OpRegistry::Global()); - Tensor lhs(DT_FLOAT, TensorShape({rows, cols})); - lhs.flat().setRandom(); - TensorShape rhs_shape; - if (dim == 0) { + TensorShape lhs_shape, rhs_shape; + if (dim == 0) { // row + lhs_shape = TensorShape({rows, cols}); rhs_shape = TensorShape({rows, 1}); - } else { + } else if (dim == 1) { // col + lhs_shape = TensorShape({rows, cols}); rhs_shape = TensorShape({cols}); + } else if (dim == 2) { // cross_rc + lhs_shape = TensorShape({rows, 1}); + rhs_shape = TensorShape({1, cols}); + } else { // cross_cr + lhs_shape = TensorShape({1, cols}); + rhs_shape = TensorShape({rows, 1}); } + Tensor lhs(DT_FLOAT, lhs_shape); + lhs.flat().setRandom(); Tensor rhs(DT_FLOAT, rhs_shape); rhs.flat().setRandom(); test::graph::Binary(g, "Add", test::graph::Constant(g, lhs), @@ -298,5 +306,59 @@ BM_BCAST_ADD_COL_ALL(sycl); #undef BM_BCAST_ADD_COL_ALL #undef BM_BCAST_ADD_COL +#define BM_BCAST_ADD_CROSS_RC(DEVICE, R, C) \ + void BM_##DEVICE##_BcastAddCrossRC_R##R##_C##C(int iters, int arg) { \ + const int rows = RowsFromArg(arg); \ + const int cols = ColsFromArg(arg); \ + const int64 tot = static_cast(iters) * rows * cols; \ + testing::ItemsProcessed(tot); \ + testing::BytesProcessed(tot * sizeof(float)); \ + test::Benchmark(#DEVICE, BcastAdd(rows, cols, 2)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_BcastAddCrossRC_R##R##_C##C) \ + ->Arg(RowsAndColsArg(R, C)); + +#define BM_BCAST_ADD_CROSS_RC_ALL(DEVICE) \ + BM_BCAST_ADD_CROSS_RC(DEVICE, 512, 2048); \ + BM_BCAST_ADD_CROSS_RC(DEVICE, 512, 4096); \ + BM_BCAST_ADD_CROSS_RC(DEVICE, 2048, 512); \ + BM_BCAST_ADD_CROSS_RC(DEVICE, 4096, 512); +BM_BCAST_ADD_CROSS_RC_ALL(cpu); +#if GOOGLE_CUDA +BM_BCAST_ADD_CROSS_RC_ALL(gpu); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +BM_BCAST_ADD_CROSS_RC_ALL(sycl); +#endif // TENSORFLOW_USE_SYCL +#undef BM_BCAST_ADD_CROSS_RC_ALL +#undef BM_BCAST_ADD_CROSS_RC + +#define BM_BCAST_ADD_CROSS_CR(DEVICE, R, C) \ + void BM_##DEVICE##_BcastAddCrossCR_R##R##_C##C(int iters, int arg) { \ + const int rows = RowsFromArg(arg); \ + const int cols = ColsFromArg(arg); \ + const int64 tot = static_cast(iters) * rows * cols; \ + testing::ItemsProcessed(tot); \ + testing::BytesProcessed(tot * sizeof(float)); \ + test::Benchmark(#DEVICE, BcastAdd(rows, cols, 3)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_BcastAddCrossCR_R##R##_C##C) \ + ->Arg(RowsAndColsArg(R, C)); + +#define BM_BCAST_ADD_CROSS_CR_ALL(DEVICE) \ + BM_BCAST_ADD_CROSS_CR(DEVICE, 512, 2048); \ + BM_BCAST_ADD_CROSS_CR(DEVICE, 512, 4096); \ + BM_BCAST_ADD_CROSS_CR(DEVICE, 2048, 512); \ + BM_BCAST_ADD_CROSS_CR(DEVICE, 4096, 512); +BM_BCAST_ADD_CROSS_CR_ALL(cpu); +#if GOOGLE_CUDA +BM_BCAST_ADD_CROSS_CR_ALL(gpu); +#endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +BM_BCAST_ADD_CROSS_CR_ALL(sycl); +#endif // TENSORFLOW_USE_SYCL +#undef BM_BCAST_ADD_CROSS_CR_ALL +#undef BM_BCAST_ADD_CROSS_CR + } // namespace } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 1e3b0c231f35c12d2e9e23d8d503b3a7492ab676..253399c1e4ec7fe8edeeeee161ef3413d1dbea09 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -209,6 +209,19 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "generator_dataset_op", + srcs = ["generator_dataset_op.cc"], + deps = [ + ":captured_function", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + tf_kernel_library( name = "scan_dataset_op", srcs = ["scan_dataset_op.cc"], @@ -498,18 +511,6 @@ tf_kernel_library( ], ) -tf_kernel_library( - name = "unique_dataset_op", - srcs = ["unique_dataset_op.cc"], - deps = [ - ":dataset", - "//tensorflow/core:dataset_ops_op_lib", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - ], -) - tf_kernel_library( name = "dataset_ops", deps = [ @@ -519,6 +520,7 @@ tf_kernel_library( ":dense_to_sparse_batch_dataset_op", ":filter_dataset_op", ":flat_map_dataset_op", + ":generator_dataset_op", ":group_by_window_dataset_op", ":interleave_dataset_op", ":iterator_ops", @@ -543,7 +545,6 @@ tf_kernel_library( ":tensor_dataset_op", ":tensor_queue_dataset_op", ":tensor_slice_dataset_op", - ":unique_dataset_op", ":zip_dataset_op", ], ) diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index f248f7897ffac8771a2e813265cccd410da4074d..dd61b7daee153bf2f3be3c72dd5c8e6032d0080b 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -33,7 +33,7 @@ Status CapturedFunction::Create( } CapturedFunction::~CapturedFunction() { - if (lib_ != nullptr) { + if (lib_ != nullptr && f_handle_ != kInvalidHandle) { lib_->ReleaseHandle(f_handle_).IgnoreError(); } } @@ -256,6 +256,62 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx, return frame.ConsumeRetvals(rets); } +Status CapturedFunction::Instantiate(IteratorContext* ctx) { + FunctionLibraryRuntime::Handle unused_handle; + TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &unused_handle)); + mutex_lock l(mu_); + if (captured_runner_ == nullptr) { + captured_runner_ = *ctx->runner(); + } + return Status::OK(); +} + +Status CapturedFunction::RunInstantiated(const std::vector& args, + std::vector* rets) { + FunctionLibraryRuntime* lib; + FunctionLibraryRuntime::Handle handle; + std::function)>* runner; + { + tf_shared_lock l(mu_); + if (lib_ == nullptr) { + return errors::FailedPrecondition( + "`CapturedFunction::Instantiate()` must be called before a call to " + "`CapturedFunction::RunInstantiated()`."); + } + lib = lib_; + handle = f_handle_; + runner = &captured_runner_; + } + + FunctionLibraryRuntime::Options f_opts; + f_opts.step_id = CapturedFunction::generate_step_id(); + ScopedStepContainer step_container(f_opts.step_id, [lib](const string& name) { + lib->device()->resource_manager()->Cleanup(name).IgnoreError(); + }); + f_opts.step_container = &step_container; + f_opts.runner = runner; + // TODO(mrry): Add cancellation manager support to IteratorContext + // so that we can cancel running map functions. The local + // cancellation manager here is created so that we can run kernels + // (such as queue kernels) that depend on the non-nullness of + // `OpKernelContext::cancellation_manager()`, but additional effort + // will be required to plumb it through the `IteratorContext`. + CancellationManager c_mgr; + f_opts.cancellation_manager = &c_mgr; + + BorrowedArgsCallFrame frame(args, &captured_inputs_, ret_types_); + Notification n; + Status s; + + lib->Run(f_opts, handle, &frame, [&n, &s](Status func_status) { + s.Update(func_status); + n.Notify(); + }); + n.WaitForNotification(); + TF_RETURN_IF_ERROR(s); + return frame.ConsumeRetvals(rets); +} + void CapturedFunction::RunAsync(IteratorContext* ctx, std::vector&& args, std::vector* rets, diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index 32d2bc3aaebf440584934231a8555199026074ae..490f5cd1e3b6676decc6646df9dfb722524d58e8 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -64,6 +64,21 @@ class CapturedFunction { const std::vector& args, std::vector* rets); + // Explicitly instantiate this function for use in the given + // context. This method, and the context-less overload + // `RunInstantiated()` below can be useful for calling a captured + // function in cases where an `IteratorContext*` is not available + // (such as a destructor). + Status Instantiate(IteratorContext* ctx); + + // Synchronously runs the captured function on the given `args`, and stores + // the results in `*rets`. Prefer to use `Run()` or `RunAsync()` when + // possible. + // + // REQUIRES: `this->Instantiate()` must have been called before this method. + Status RunInstantiated(const std::vector& args, + std::vector* rets); + // Asynchronously runs the captured function on the given `args`, stores // the results in `*rets`, and calls the given `done` callback when the // function returns. This method takes ownership of the tensors in `args`, @@ -99,6 +114,7 @@ class CapturedFunction { FunctionLibraryRuntime::Handle f_handle_ GUARDED_BY(mu_); const std::vector captured_inputs_; DataTypeSlice ret_types_; + std::function)> captured_runner_ = nullptr; TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); }; diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3f1e441b91d0102b112523a46ac75ce415eacdd7 --- /dev/null +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -0,0 +1,201 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/captured_function.h" +#include "tensorflow/core/lib/random/random.h" + +namespace tensorflow { + +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. + +class GeneratorDatasetOp : public DatasetOpKernel { + public: + explicit GeneratorDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("next_func", &next_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("finalize_func", &finalize_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + OpInputList init_func_other_args_input; + OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args", + &init_func_other_args_input)); + std::vector init_func_other_args; + init_func_other_args.reserve(init_func_other_args_input.size()); + for (const Tensor& t : init_func_other_args_input) { + init_func_other_args.push_back(t); + } + std::unique_ptr init_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create( + init_func_, std::move(init_func_other_args), &init_func)); + + OpInputList next_func_other_args_input; + OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args", + &next_func_other_args_input)); + std::vector next_func_other_args; + next_func_other_args.reserve(next_func_other_args_input.size()); + for (const Tensor& t : next_func_other_args_input) { + next_func_other_args.push_back(t); + } + std::unique_ptr next_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create( + next_func_, std::move(next_func_other_args), &next_func)); + + OpInputList finalize_func_other_args_input; + OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args", + &finalize_func_other_args_input)); + std::vector finalize_func_other_args; + finalize_func_other_args.reserve(finalize_func_other_args_input.size()); + for (const Tensor& t : finalize_func_other_args_input) { + finalize_func_other_args.push_back(t); + } + std::unique_ptr finalize_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create( + finalize_func_, std::move(finalize_func_other_args), + &finalize_func)); + + *output = + new Dataset(ctx, std::move(init_func), std::move(next_func), + std::move(finalize_func), output_types_, output_shapes_); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, std::unique_ptr init_func, + std::unique_ptr next_func, + std::unique_ptr finalize_func, + const DataTypeVector& output_types, + const std::vector& output_shapes) + : GraphDatasetBase(ctx), + init_func_(std::move(init_func)), + next_func_(std::move(next_func)), + finalize_func_(std::move(finalize_func)), + output_types_(output_types), + output_shapes_(output_shapes) {} + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Generator")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() override { return "GeneratorDatasetOp::Dataset"; } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + ~Iterator() override { + if (!finalized_) { + std::vector ignored; + Status s = + dataset()->finalize_func_->RunInstantiated(state_, &ignored); + if (!s.ok()) { + LOG(WARNING) + << "Error occurred when finalizing GeneratorDataset iterator: " + << s; + } + } + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + + if (!initialized_) { + TF_RETURN_IF_ERROR( + dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); + // Explicitly instantiate the finalize function here so that + // we can invoke it in the destructor. + TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); + initialized_ = true; + } + + if (finalized_) { + *end_of_sequence = true; + return Status::OK(); + } + + Status s = dataset()->next_func_->RunWithBorrowedArgs(ctx, state_, + out_tensors); + if (s.ok()) { + *end_of_sequence = false; + } else if (errors::IsOutOfRange(s)) { + // `next_func` may deliberately raise `errors::OutOfRange` + // to indicate that we should terminate the iteration. + s = Status::OK(); + *end_of_sequence = true; + + // NOTE(mrry): We ignore any tensors returned by the + // finalize function. + std::vector ignored; + TF_RETURN_IF_ERROR( + dataset()->finalize_func_->RunInstantiated(state_, &ignored)); + finalized_ = true; + } + return s; + } + + private: + mutex mu_; + bool initialized_ GUARDED_BY(mu_) = false; + bool finalized_ GUARDED_BY(mu_) = false; + std::vector state_ GUARDED_BY(mu_); + }; + + const std::unique_ptr init_func_; + const std::unique_ptr next_func_; + const std::unique_ptr finalize_func_; + const DataTypeVector output_types_; + const std::vector output_shapes_; + }; + + DataTypeVector output_types_; + std::vector output_shapes_; + NameAttrList init_func_; + NameAttrList next_func_; + NameAttrList finalize_func_; +}; + +REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU), + GeneratorDatasetOp); + +} // namespace + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index bc4426a9fdbab971a4e49d57ffcea6896fc037a7..33053b1bd9d7878016ebaf96b75c5c4b30130c4b 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -199,7 +199,14 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { } } ++num_outputs_consumed_; - return result->status; + if (errors::IsOutOfRange(result->status)) { + // `f` may deliberately raise `errors::OutOfRange` to indicate + // that we should terminate the iteration early. + *end_of_sequence = true; + return Status::OK(); + } else { + return result->status; + } } protected: diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 1cb533158bb5b8bd4b950192ce67e17c0f9d5447..d37086541dc4714162e00cc6d022b3bd300e3a1c 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -187,12 +187,11 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { } else { input_impl_.reset(); if (first_call) { - // If the first call to GetNext() fails because the end of - // sequence has been reached, we return an OutOfRange error to - // terminate the iteration. (Otherwise, this iterator would loop - // infinitely and never produce a value.) - return errors::OutOfRange( - "Attempted to repeat an empty dataset infinitely."); + // If the first call to GetNext() fails because the end + // of sequence has been reached, we terminate the + // iteration immediately. (Otherwise, this iterator + // would loop infinitely and never produce a value.) + return Status::OK(); } } } while (true); diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 1dde236c1711afd794ff397859631a48984b5ba8..2f6bf83da5d4f1d4b431e6849fd6571f56539dfe 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -104,13 +104,12 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { break; } if (first_call && dataset()->count_ == -1) { - // If the first call to GetNext() fails because the end of - // sequence has been reached, we return an OutOfRange error to - // terminate the iteration. (Otherwise, this iterator may loop - // infinitely and never produce a value.) + // If the first call to GetNext() fails because the end + // of sequence has been reached, we terminate the + // iteration immediately. (Otherwise, this iterator + // would loop infinitely and never produce a value.) *end_of_sequence = true; - return errors::OutOfRange( - "Attempted to repeat an empty dataset infinitely."); + return Status::OK(); } epoch_++; int64 n = slices_.back()->end; diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc index b7d120a617849b2c1a48b38b959f9941eb8503ac..b4dcf0a74b336e6173843be233c370d624a9a8e2 100644 --- a/tensorflow/core/kernels/decode_bmp_op.cc +++ b/tensorflow/core/kernels/decode_bmp_op.cc @@ -91,15 +91,32 @@ class DecodeBmpOp : public OpKernel { errors::InvalidArgument( "Number of channels must be 1, 3 or 4, was ", channels_)); + OP_REQUIRES(context, width > 0 && header_size >= 0, + errors::InvalidArgument("Width must be positive")); + OP_REQUIRES(context, header_size >= 0, + errors::InvalidArgument("header size must be nonnegative")); + + // The real requirement is < 2^31 minus some headers and channel data, + // so rounding down to something that's still ridiculously big. + OP_REQUIRES( + context, + (static_cast(width) * std::abs(static_cast(height))) < + static_cast(std::numeric_limits::max() / 8), + errors::InvalidArgument( + "Total possible pixel bytes must be less than 2^30")); + + const int32 abs_height = abs(height); + // there may be padding bytes when the width is not a multiple of 4 bytes // 8 * channels == bits per pixel const int row_size = (8 * channels_ * width + 31) / 32 * 4; - const int last_pixel_offset = - header_size + (abs(height) - 1) * row_size + (width - 1) * channels_; + const int64 last_pixel_offset = static_cast(header_size) + + (abs_height - 1) * row_size + + (width - 1) * channels_; // [expected file size] = [last pixel offset] + [last pixel size=channels] - const int expected_file_size = last_pixel_offset + channels_; + const int64 expected_file_size = last_pixel_offset + channels_; OP_REQUIRES( context, (expected_file_size <= input.size()), @@ -115,12 +132,12 @@ class DecodeBmpOp : public OpKernel { Tensor* output = nullptr; OP_REQUIRES_OK( context, context->allocate_output( - 0, TensorShape({abs(height), width, channels_}), &output)); + 0, TensorShape({abs_height, width, channels_}), &output)); const uint8* bmp_pixels = &img_bytes[header_size]; Decode(bmp_pixels, row_size, output->flat().data(), width, - abs(height), channels_, top_down); + abs_height, channels_, top_down); } uint8* Decode(const uint8* input, const int row_size, uint8* const output, diff --git a/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc b/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc index c9c97dc072c93e3ab840a8a9c9d81eadd2adaa3c..9a3b2303a3bf6718009b5055c4ef25464ec01136 100644 --- a/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/dense_update_functor_gpu.cu.cc @@ -57,6 +57,7 @@ struct DenseUpdate { template struct functor::DenseUpdate; \ template struct functor::DenseUpdate; TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); +TF_CALL_int64(DEFINE_GPU_KERNELS); #undef DEFINE_GPU_KERNELS #define DEFINE_GPU_KERNELS(T) \ diff --git a/tensorflow/core/kernels/dense_update_ops.cc b/tensorflow/core/kernels/dense_update_ops.cc index 6497c8f3719737ede2d261decd16f01c9854a7eb..0de97de20523ad54c08aa7b4190438c1da6ebde7 100644 --- a/tensorflow/core/kernels/dense_update_ops.cc +++ b/tensorflow/core/kernels/dense_update_ops.cc @@ -109,6 +109,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); AssignOpT); TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA @@ -142,6 +143,7 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); Name("AssignSub").Device(DEVICE_GPU).TypeConstraint("T"), \ DenseUpdateOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // end GOOGLE_CUDA diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index 1e9345828ad25fd18262e400a913cbc39ff09fae..94989089ec9cdf9314860b43f67691f39f33c31f 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -24,12 +24,12 @@ limitations under the License. #include "tensorflow/core/util/cuda_kernel_helper.h" #include "tensorflow/core/util/tensor_format.h" -#if !defined(_MSC_VER) -#define UNROLL _Pragma("unroll") -#define NOUNROLL _Pragma("nounroll") -#else +#if defined(_MSC_VER) && !defined(__clang__) #define UNROLL #define NOUNROLL +#else +#define UNROLL _Pragma("unroll") +#define NOUNROLL _Pragma("nounroll") #endif namespace tensorflow { @@ -52,13 +52,13 @@ EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dGPUSmall( // Returns whether depthwise convolution backward filter pass can be performed // using the faster ('Small') variant of the kernel. EIGEN_DEVICE_FUNC bool CanLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const DepthwiseArgs& args, const int block_rows) { + const DepthwiseArgs& args, const int block_height) { return args.depth_multiplier == 1 && args.stride == 1 && args.in_rows <= 32 && args.in_cols <= 32 && args.in_rows == args.out_rows && args.in_cols == args.out_cols && args.pad_rows >= 0 && args.pad_rows < args.filter_rows && args.pad_cols >= 0 && - args.pad_cols < args.filter_cols && block_rows <= args.in_rows && - args.filter_rows * args.filter_cols <= args.in_cols * block_rows; + args.pad_cols < args.filter_cols && block_height <= args.in_rows && + args.filter_rows * args.filter_cols <= args.in_cols * block_height; } // The DepthwiseConv2dGPUKernels perform either forward or backprop input @@ -72,72 +72,81 @@ template (0); - const int input_offset_temp = in_rows * OB; + const int input_offset_temp = in_height * batch; if (input_row_start >= 0 && input_col_start >= 0 && - input_row_end < in_rows && input_col_end < in_cols) { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = input_row_start + f_r; - const int filter_offset_temp = filter_cols * f_r; - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = input_col_start + f_c; + input_row_end < in_height && input_col_end < in_width) { + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; const int input_offset = - in_d + in_depth * (in_c + in_cols * (in_r + input_offset_temp)); + in_channel + + in_depth * (in_col + in_width * (in_row + input_offset_temp)); const int filter_offset = multiplier + - depth_multiplier * (in_d + in_depth * (f_c + filter_offset_temp)); + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); sum += ldg(input + input_offset) * ldg(filter + filter_offset); } } } else { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = input_row_start + f_r; - const int filter_offset_temp = filter_cols * f_r; - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = input_col_start + f_c; - if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols) { - const int in_c = input_col_start + f_c; + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { + const int in_col = input_col_start + filter_col; const int input_offset = - in_d + in_depth * (in_c + in_cols * (in_r + input_offset_temp)); + in_channel + + in_depth * (in_col + in_width * (in_row + input_offset_temp)); const int filter_offset = - multiplier + depth_multiplier * - (in_d + in_depth * (f_c + filter_offset_temp)); + multiplier + + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); sum += ldg(input + input_offset) * ldg(filter + filter_offset); } } @@ -157,8 +166,8 @@ __global__ void __launch_bounds__(1024, 2) // Backprop input direction is the same as forward direction with the filter // rotated by 180°. template + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth, + bool kKnownEvenHeight> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( const DepthwiseArgs args, const T* input, const T* filter, T* output) { assert(CanLaunchDepthwiseConv2dGPUSmall(args)); @@ -166,45 +175,47 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[]; T* const shared_data = reinterpret_cast(shared_memory); - const int batches = args.batch; - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; - const int block_rows = blockDim.z; + assert(blockDim.x == kBlockDepth); + assert(blockDim.y == args.in_cols); + const int block_height = blockDim.z; // These values are the same for all threads and could // be precomputed on the CPU. - const int block_size = block_rows * in_cols * kBlockSlices; - const int in_row_size = in_cols * in_depth; - const int in_size = in_rows * in_row_size; - const int in_increment = (in_cols - 1) * kBlockSlices; - const int filter_pixels = filter_rows * filter_cols; - const int tile_cols = in_cols + filter_cols - 1; - const int even_rows = kKnownEvenRows || (1 & ~in_rows); - const int tile_rows = in_rows + filter_rows - even_rows; - const int tile_row_size = tile_cols * kBlockSlices; - const int tile_size = tile_rows * tile_row_size; - const int tile_offset = block_rows * tile_row_size; - const int pad_offset = pad_rows * tile_cols + pad_cols; - const int batch_blocks = (in_depth + kBlockSlices - 1) / kBlockSlices; - const int in_blocks = batch_blocks * batches; + const int block_size = block_height * in_width * kBlockDepth; + const int in_row_size = in_width * in_depth; + const int in_size = in_height * in_row_size; + const int in_increment = (in_width - 1) * kBlockDepth; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int even_height = kKnownEvenHeight || (1 & ~in_height); + const int tile_height = in_height + filter_height - even_height; + const int tile_row_size = tile_width * kBlockDepth; + const int tile_size = tile_height * tile_row_size; + const int tile_offset = block_height * tile_row_size; + const int pad_offset = pad_height * tile_width + pad_width; + const int batch_blocks = (in_depth + kBlockDepth - 1) / kBlockDepth; + const int in_blocks = batch_blocks * num_batches; const int tensor_offset = - kKnownEvenRows ? in_size / 2 : block_rows * in_row_size; + kKnownEvenHeight ? in_size / 2 : block_height * in_row_size; const int thread_depth = threadIdx.x; const int thread_col = threadIdx.y; const int thread_row = threadIdx.z; // Position in block. - const int thread_pix = thread_row * in_cols + thread_col; - const int thread_idx = thread_pix * kBlockSlices + thread_depth; + const int thread_pix = thread_row * in_width + thread_col; + const int thread_idx = thread_pix * kBlockDepth + thread_depth; // Initialize tile, in particular the padding. for (int i = thread_idx; i < tile_size; i += block_size) { @@ -216,32 +227,32 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( const int tensor_idx = thread_pix * in_depth + thread_depth; // Position in (padded) shared memory. - const int data_pix = thread_row * tile_cols + thread_col; - const int data_idx = data_pix * kBlockSlices + thread_depth; + const int data_pix = thread_row * tile_width + thread_col; + const int data_idx = data_pix * kBlockDepth + thread_depth; - // Position in shared memory, offset by pad_rows / pad_cols. + // Position in shared memory, offset by pad_height / pad_width. const int tile_pix = data_pix + pad_offset; - const int tile_idx = tile_pix * kBlockSlices + thread_depth; + const int tile_idx = tile_pix * kBlockDepth + thread_depth; - const int max_depth = in_depth - thread_depth; + const int max_channel = in_depth - thread_depth; const int filter_write_offset = thread_pix < filter_pixels ? tile_size + thread_idx : 0; const int filter_read_offset = tile_size + thread_depth + - (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockSlices); + (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockDepth); const bool skip_second = - !kKnownEvenRows && thread_row + (in_rows & 1) == block_rows; + !kKnownEvenHeight && thread_row + (in_height & 1) == block_height; for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { const int batch = b / batch_blocks; - const int stack = b - batch * batch_blocks; + const int block = b - batch * batch_blocks; - const int start_depth = stack * kBlockSlices; - const int filter_offset = tensor_idx + start_depth; + const int start_channel = block * kBlockDepth; + const int filter_offset = tensor_idx + start_channel; const int inout_offset = batch * in_size + filter_offset; - const bool depth_in_range = start_depth < max_depth; + const bool channel_in_range = start_channel < max_channel; - if (depth_in_range) { + if (channel_in_range) { const T* const in_ptr = inout_offset + input; T* const tile_ptr = tile_idx + shared_data; tile_ptr[0] = ldg(in_ptr); @@ -257,23 +268,23 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNHWCSmall( // Note: the condition to reach this is uniform across the entire block. __syncthreads(); - if (depth_in_range) { + if (channel_in_range) { T sum1 = static_cast(0); T sum2 = static_cast(0); int shared_offset = data_idx; const T* filter_ptr = filter_read_offset + shared_data; - UNROLL for (int r = 0; r < filter_rows; ++r) { - UNROLL for (int c = 0; c < filter_cols; ++c) { + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { if (kDirection == DIRECTION_BACKWARD) { - filter_ptr -= kBlockSlices; + filter_ptr -= kBlockDepth; } const T filter_value = *filter_ptr; const T* const tile_ptr = shared_offset + shared_data; sum1 += filter_value * tile_ptr[0]; sum2 += filter_value * tile_ptr[tile_offset]; - shared_offset += kBlockSlices; + shared_offset += kBlockDepth; if (kDirection == DIRECTION_FORWARD) { - filter_ptr += kBlockSlices; + filter_ptr += kBlockDepth; } } shared_offset += in_increment; @@ -297,20 +308,20 @@ template (0); if (input_row_start >= 0 && input_col_start >= 0 && - input_row_end < in_rows && input_col_end < in_cols) { + input_row_end < in_height && input_col_end < in_width) { // Loop that doesn't need to check for boundary conditions. - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = input_row_start + f_r; - const int filter_offset_temp = filter_cols * f_r; - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = input_col_start + f_c; + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; const int input_offset = - (input_offset_temp) + (in_r * in_cols) + in_c; + (input_offset_temp) + (in_row * in_width) + in_col; const int filter_offset = multiplier + - depth_multiplier * (in_d + in_depth * (f_c + filter_offset_temp)); + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); sum += ldg(input + input_offset) * ldg(filter + filter_offset); } } } else { // Loop that needs to check for boundary conditions. - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = input_row_start + f_r; - const int filter_offset_temp = filter_cols * f_r; - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = input_col_start + f_c; - // TODO(vrv): the in_r check can be done outside of this loop; + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = input_row_start + filter_row; + const int filter_offset_temp = filter_width * filter_row; + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = input_col_start + filter_col; + // TODO(vrv): the in_row check can be done outside of this loop; // benchmark both methods to determine the better decision. - if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols) { - const int in_c = input_col_start + f_c; + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { + const int in_col = input_col_start + filter_col; // input_offset_temp indexes into the start of memory // where the spatial data starts. const int input_offset = - (input_offset_temp) + (in_r * in_cols) + in_c; + (input_offset_temp) + (in_row * in_width) + in_col; const int filter_offset = - multiplier + depth_multiplier * - (in_d + in_depth * (f_c + filter_offset_temp)); + multiplier + + depth_multiplier * + (in_channel + in_depth * (filter_col + filter_offset_temp)); sum += ldg(input + input_offset) * ldg(filter + filter_offset); } } @@ -427,8 +446,8 @@ __global__ void __launch_bounds__(1024, 2) // Backprop input direction is the same as forward direction with the filter // rotated by 180°. template + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth, + bool kKnownEvenHeight> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( const DepthwiseArgs args, const T* input, const T* filter, T* output) { assert(CanLaunchDepthwiseConv2dGPUSmall(args)); @@ -436,43 +455,45 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[]; T* const shared_data = reinterpret_cast(shared_memory); - const int batches = args.batch; - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; // Fixed blockDim.z, tailored for maximum grid size for images of size 16x16. - const int block_rows = blockDim.y; + assert(blockDim.x == args.in_cols); + assert(blockDim.z == kBlockDepth); + const int block_height = blockDim.y; // These values are the same for all threads and could // be precomputed on the CPU. - const int block_pixels = in_cols * block_rows; - const int block_size = block_pixels * kBlockSlices; - const int in_pixels = in_cols * in_rows; - const int in_increment = in_cols - 1; - const int filter_pixels = filter_rows * filter_cols; - const int tile_cols = in_cols + filter_cols - 1; - const int even_rows = kKnownEvenRows || (1 & ~in_rows); - const int tile_rows = in_rows + filter_rows - even_rows; - const int tile_pixels = tile_cols * tile_rows; - const int tile_size = tile_pixels * kBlockSlices; - const int tile_offset = block_rows * tile_cols; - const int pad_offset = pad_rows * tile_cols + pad_cols; - const int in_slices = in_depth * batches; - const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices; + const int block_pixels = in_width * block_height; + const int block_size = block_pixels * kBlockDepth; + const int in_pixels = in_width * in_height; + const int in_increment = in_width - 1; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int even_height = kKnownEvenHeight || (1 & ~in_height); + const int tile_height = in_height + filter_height - even_height; + const int tile_pixels = tile_width * tile_height; + const int tile_size = tile_pixels * kBlockDepth; + const int tile_offset = block_height * tile_width; + const int pad_offset = pad_height * tile_width + pad_width; + const int in_total_depth = in_depth * num_batches; + const int in_blocks = (in_total_depth + kBlockDepth - 1) / kBlockDepth; const int thread_col = threadIdx.x; const int thread_row = threadIdx.y; const int thread_depth = threadIdx.z; // Position in block. - const int thread_pix = thread_row * in_cols + thread_col; + const int thread_pix = thread_row * in_width + thread_col; const int thread_idx = thread_depth * block_pixels + thread_pix; // Initialize tile, in particular the padding. @@ -485,33 +506,33 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( const int tensor_idx = thread_depth * in_pixels + thread_pix; // Position in (padded) shared memory. - const int data_pix = thread_row * tile_cols + thread_col; + const int data_pix = thread_row * tile_width + thread_col; const int data_idx = thread_depth * tile_pixels + data_pix; - // Position in shared memory, offset by pad_rows / pad_cols. + // Position in shared memory, offset by pad_height / pad_width. const int tile_idx = data_idx + pad_offset; // Filter is always in HWCK format, irrespective of the input/output format. - const int filter_pix = thread_idx / kBlockSlices; - const int filter_depth = thread_idx % kBlockSlices; + const int filter_pix = thread_idx / kBlockDepth; + const int filter_channel = thread_idx % kBlockDepth; const int filter_idx = filter_pix * in_depth; - const int max_slice = in_slices - thread_depth; + const int max_channel = in_total_depth - thread_depth; const int filter_write_offset = filter_pix < filter_pixels ? tile_size + thread_idx : 0; const int filter_read_offset = tile_size + thread_depth + - (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockSlices); + (kDirection == DIRECTION_FORWARD ? 0 : filter_pixels * kBlockDepth); const bool skip_second = - !kKnownEvenRows && thread_row + (in_rows & 1) == block_rows; + !kKnownEvenHeight && thread_row + (in_height & 1) == block_height; for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { - const int slice = b * kBlockSlices; + const int channel = b * kBlockDepth; - const int inout_offset = slice * in_pixels + tensor_idx; - const bool slice_in_range = slice < max_slice; + const int inout_offset = channel * in_pixels + tensor_idx; + const bool channel_in_range = channel < max_channel; - if (slice_in_range) { + if (channel_in_range) { const T* const in_ptr = inout_offset + input; T* const tile_ptr = tile_idx + shared_data; tile_ptr[0] = ldg(in_ptr); @@ -521,22 +542,23 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( } if (filter_write_offset != 0) { - const int filter_offset = filter_idx + (slice + filter_depth) % in_depth; + const int filter_offset = + filter_idx + (channel + filter_channel) % in_depth; shared_data[filter_write_offset] = ldg(filter_offset + filter); } // Note: the condition to reach this is uniform across the entire block. __syncthreads(); - if (slice_in_range) { + if (channel_in_range) { T sum1 = static_cast(0); T sum2 = static_cast(0); int shared_offset = data_idx; const T* filter_ptr = filter_read_offset + shared_data; - UNROLL for (int r = 0; r < filter_rows; ++r) { - UNROLL for (int c = 0; c < filter_cols; ++c) { + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { if (kDirection == DIRECTION_BACKWARD) { - filter_ptr -= kBlockSlices; + filter_ptr -= kBlockDepth; } const T filter_value = *filter_ptr; const T* const tile_ptr = shared_offset + shared_data; @@ -544,7 +566,7 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( sum2 += filter_value * tile_ptr[tile_offset]; ++shared_offset; if (kDirection == DIRECTION_FORWARD) { - filter_ptr += kBlockSlices; + filter_ptr += kBlockDepth; } } shared_offset += in_increment; @@ -562,133 +584,148 @@ __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dGPUKernelNCHWSmall( } template -void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth, + bool kKnownEvenHeight> +void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device, const DepthwiseArgs& args, const T* input, const T* filter, T* output, TensorFormat data_format) { - const int block_rows = (args.in_rows + 1) / 2; + const int block_height = (args.in_rows + 1) / 2; dim3 block_dim; + int block_count; void (*kernel)(const DepthwiseArgs, const T*, const T*, T*); - if (data_format == FORMAT_NHWC) { - block_dim = dim3(kBlockSlices, args.in_cols, block_rows); - kernel = DepthwiseConv2dGPUKernelNHWCSmall; - } else if (data_format == FORMAT_NCHW) { - block_dim = dim3(args.in_cols, block_rows, kBlockSlices); - kernel = DepthwiseConv2dGPUKernelNCHWSmall; - } else { - assert(false && "Incorrect data format"); - return; + switch (data_format) { + case FORMAT_NHWC: + block_dim = dim3(kBlockDepth, args.in_cols, block_height); + block_count = + args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth; + kernel = + DepthwiseConv2dGPUKernelNHWCSmall; + break; + case FORMAT_NCHW: + block_dim = dim3(args.in_cols, block_height, kBlockDepth); + block_count = + DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth; + kernel = + DepthwiseConv2dGPUKernelNCHWSmall; + break; + case FORMAT_NCHW_VECT_C: + LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + return; } - const int tile_cols = args.in_cols + args.filter_cols - 1; - const int tile_rows = block_rows * 2 + args.filter_rows - 1; - const int tile_pixels = tile_rows * tile_cols; + const int tile_width = args.in_cols + args.filter_cols - 1; + const int tile_height = block_height * 2 + args.filter_rows - 1; + const int tile_pixels = tile_height * tile_width; const int filter_pixels = args.filter_rows * args.filter_cols; const int shared_memory_size = - kBlockSlices * (tile_pixels + filter_pixels) * sizeof(T); - const int num_outputs = - args.batch * args.out_rows * args.out_cols * args.out_depth; - CudaLaunchConfig config = - GetCudaLaunchConfig(num_outputs, d, kernel, shared_memory_size, - block_dim.x * block_dim.y * block_dim.z); - kernel<<>>( - args, input, filter, output); + kBlockDepth * (tile_pixels + filter_pixels) * sizeof(T); + const int num_outputs = args.out_rows * args.out_cols * block_count; + CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( + num_outputs, device, kernel, shared_memory_size, + block_dim.x * block_dim.y * block_dim.z); + kernel<<>>(args, input, filter, output); } template -void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, + int kKnownFilterWidth, int kKnownFilterHeight, int kBlockDepth> +void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device, const DepthwiseArgs& args, const T* input, const T* filter, T* output, TensorFormat data_format) { if (args.in_rows & 1) { LaunchDepthwiseConv2dGPUSmall( - d, args, input, filter, output, data_format); + kKnownFilterHeight, kBlockDepth, false>( + device, args, input, filter, output, data_format); } else { LaunchDepthwiseConv2dGPUSmall( - d, args, input, filter, output, data_format); + kKnownFilterHeight, kBlockDepth, true>( + device, args, input, filter, output, data_format); } } template -void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& d, +void LaunchDepthwiseConv2dGPUSmall(const GpuDevice& device, const DepthwiseArgs& args, const T* input, const T* filter, T* output, TensorFormat data_format) { - // Maximize (power of two) kBlockSlices while keeping a block within 1024 + // Maximize (power of two) kBlockDepth while keeping a block within 1024 // threads (2 pixels per thread). const int block_pixels = (args.in_rows + 1) / 2 * args.in_cols; if (block_pixels > 256) { LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, - output, data_format); + kKnownFilterHeight, 2>( + device, args, input, filter, output, data_format); } else if (block_pixels > 128) { LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, - output, data_format); + kKnownFilterHeight, 4>( + device, args, input, filter, output, data_format); } else { LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, - output, data_format); + kKnownFilterHeight, 8>( + device, args, input, filter, output, data_format); } } template -void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs& args, - const T* input, const T* filter, T* output, +void LaunchDepthwiseConv2dGPU(const GpuDevice& device, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, TensorFormat data_format) { void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); - if (data_format == FORMAT_NHWC) { - kernel = - DepthwiseConv2dGPUKernelNHWC; - } else if (data_format == FORMAT_NCHW) { - kernel = - DepthwiseConv2dGPUKernelNCHW; - } else { - assert(false && "Incorrect data format"); - return; + switch (data_format) { + case FORMAT_NHWC: + kernel = + DepthwiseConv2dGPUKernelNHWC; + break; + case FORMAT_NCHW: + kernel = + DepthwiseConv2dGPUKernelNCHW; + break; + case FORMAT_NCHW_VECT_C: + LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + return; } const int num_outputs = args.batch * args.out_rows * args.out_cols * args.out_depth; - CudaLaunchConfig config = GetCudaLaunchConfig(num_outputs, d, kernel, 0, 0); + CudaLaunchConfig config = + GetCudaLaunchConfig(num_outputs, device, kernel, 0, 0); // The compile-time constant version runs faster with a single block. const int max_block_count = kKnownFilterWidth < 0 || kKnownFilterHeight < 0 || kKnownDepthMultiplier < 0 ? std::numeric_limits::max() - : d.getNumCudaMultiProcessors(); + : device.getNumCudaMultiProcessors(); kernel<<>>(args, input, filter, - output, num_outputs); + config.thread_per_block, 0, device.stream()>>>(args, input, filter, + output, num_outputs); } template -void LaunchDepthwiseConv2dGPU(const GpuDevice& d, const DepthwiseArgs& args, - const T* input, const T* filter, T* output, +void LaunchDepthwiseConv2dGPU(const GpuDevice& device, + const DepthwiseArgs& args, const T* input, + const T* filter, T* output, TensorFormat data_format) { if (args.depth_multiplier == 1) { if (CanLaunchDepthwiseConv2dGPUSmall(args)) { LaunchDepthwiseConv2dGPUSmall(d, args, input, filter, - output, data_format); + kKnownFilterHeight>( + device, args, input, filter, output, data_format); return; } LaunchDepthwiseConv2dGPU( - d, args, input, filter, output, data_format); + device, args, input, filter, output, data_format); } else { LaunchDepthwiseConv2dGPU( - d, args, input, filter, output, data_format); + device, args, input, filter, output, data_format); } } @@ -699,12 +736,12 @@ void LaunchDepthwiseConvOp::operator()(OpKernelContext* ctx, const T* input, const T* filter, T* output, TensorFormat data_format) { - const GpuDevice& d = ctx->eigen_device(); + const GpuDevice& device = ctx->eigen_device(); if (args.filter_rows == 3 && args.filter_cols == 3) { - LaunchDepthwiseConv2dGPU(d, args, input, filter, output, + LaunchDepthwiseConv2dGPU(device, args, input, filter, output, data_format); } else { - LaunchDepthwiseConv2dGPU(d, args, input, filter, output, + LaunchDepthwiseConv2dGPU(device, args, input, filter, output, data_format); } auto stream = ctx->op_device_context()->stream(); @@ -725,59 +762,65 @@ __global__ void __launch_bounds__(640, 2) const T* out_backprop, const T* filter, T* in_backprop, int num_in_backprop) { - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int in_height = args.in_rows; + const int in_width = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; const int depth_multiplier = kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; - const int out_rows = args.out_rows; - const int out_cols = args.out_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_height = args.out_rows; + const int out_width = args.out_cols; const int out_depth = args.out_depth; CUDA_1D_KERNEL_LOOP(thread_id, num_in_backprop) { // Compute the indexes of this thread in the output. - const int in_d = thread_id % in_depth; - const int in_c = (thread_id / in_depth) % in_cols; - const int in_r = (thread_id / in_depth / in_cols) % in_rows; - const int b = thread_id / in_depth / in_cols / in_rows; + const int in_channel = thread_id % in_depth; + const int in_col = (thread_id / in_depth) % in_width; + const int in_row = (thread_id / in_depth / in_width) % in_height; + const int batch = thread_id / in_depth / in_width / in_height; T sum = static_cast(0); - const int out_r_start = - tf_max(0, (in_r - filter_rows + pad_rows + stride) / stride); - const int out_r_end = tf_min(out_rows - 1, (in_r + pad_rows) / stride); - const int out_c_start = - tf_max(0, (in_c - filter_cols + pad_cols + stride) / stride); - const int out_c_end = tf_min(out_cols - 1, (in_c + pad_cols) / stride); - - NOUNROLL for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) { - const int f_r = in_r + pad_rows - out_r * stride; + const int out_row_start = + tf_max(0, (in_row - filter_height + pad_height + stride) / stride); + const int out_row_end = + tf_min(out_height - 1, (in_row + pad_height) / stride); + const int out_col_start = + tf_max(0, (in_col - filter_width + pad_width + stride) / stride); + const int out_col_end = + tf_min(out_width - 1, (in_col + pad_width) / stride); + + NOUNROLL for (int out_row = out_row_start; out_row <= out_row_end; + ++out_row) { + const int filter_row = in_row + pad_height - out_row * stride; const int temp_out_backprop_offset = - out_depth * out_cols * (out_r + out_rows * b); - const int temp_filter_offset = filter_cols * f_r; - NOUNROLL for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) { - const int f_c = in_c + pad_cols - out_c * stride; + out_depth * out_width * (out_row + out_height * batch); + const int temp_filter_offset = filter_width * filter_row; + NOUNROLL for (int out_col = out_col_start; out_col <= out_col_end; + ++out_col) { + const int filter_col = in_col + pad_width - out_col * stride; int filter_offset = - depth_multiplier * (in_d + in_depth * (f_c + temp_filter_offset)); + depth_multiplier * + (in_channel + in_depth * (filter_col + temp_filter_offset)); const int out_backprop_offset = - out_depth * out_c + temp_out_backprop_offset; + out_depth * out_col + temp_out_backprop_offset; #pragma unroll 6 for (int i = 0; i < depth_multiplier; ++i) { sum += ldg(out_backprop + out_backprop_offset + - in_d * depth_multiplier + i) * + in_channel * depth_multiplier + i) * ldg(filter + filter_offset + i); } } } const int in_backprop_offset = - in_d + in_depth * (in_c + in_cols * (in_r + in_rows * b)); + in_channel + + in_depth * (in_col + in_width * (in_row + in_height * batch)); in_backprop[in_backprop_offset] = sum; } } @@ -789,98 +832,107 @@ __global__ void __launch_bounds__(640, 2) const T* out_backprop, const T* filter, T* in_backprop, int num_in_backprop) { - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int in_height = args.in_rows; + const int in_width = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; const int depth_multiplier = kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; - const int out_rows = args.out_rows; - const int out_cols = args.out_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_height = args.out_rows; + const int out_width = args.out_cols; const int out_depth = args.out_depth; // TODO(vrv): Consider assigning threads to output and using // atomics for accumulation, similar to the filter case. CUDA_1D_KERNEL_LOOP(thread_id, num_in_backprop) { // Compute the indexes of this thread in the input. - const int in_c = thread_id % in_cols; - const int in_r = (thread_id / in_cols) % in_rows; - const int in_d = (thread_id / in_cols / in_rows) % in_depth; - const int b = thread_id / in_depth / in_cols / in_rows; + const int in_col = thread_id % in_width; + const int in_row = (thread_id / in_width) % in_height; + const int in_channel = (thread_id / in_width / in_height) % in_depth; + const int batch = thread_id / in_depth / in_width / in_height; T sum = static_cast(0); - const int out_d_start = in_d * depth_multiplier; - const int out_d_end = out_d_start + depth_multiplier; - - const int out_r_start = - tf_max(0, (in_r - filter_rows + pad_rows + stride) / stride); - const int out_r_end = tf_min(out_rows - 1, (in_r + pad_rows) / stride); - const int out_c_start = - tf_max(0, (in_c - filter_cols + pad_cols + stride) / stride); - const int out_c_end = tf_min(out_cols - 1, (in_c + pad_cols) / stride); - - UNROLL for (int out_d = out_d_start; out_d < out_d_end; ++out_d) { - UNROLL for (int out_r = out_r_start; out_r <= out_r_end; ++out_r) { - const int f_r = in_r + pad_rows - out_r * stride; - const int filter_dm = out_d - out_d_start; - - const int temp_filter_offset = filter_cols * f_r; - for (int out_c = out_c_start; out_c <= out_c_end; ++out_c) { - const int f_c = in_c + pad_cols - out_c * stride; + const int out_channel_start = in_channel * depth_multiplier; + const int out_channel_end = out_channel_start + depth_multiplier; + + const int out_row_start = + tf_max(0, (in_row - filter_height + pad_height + stride) / stride); + const int out_row_end = + tf_min(out_height - 1, (in_row + pad_height) / stride); + const int out_col_start = + tf_max(0, (in_col - filter_width + pad_width + stride) / stride); + const int out_col_end = + tf_min(out_width - 1, (in_col + pad_width) / stride); + + UNROLL for (int out_channel = out_channel_start; + out_channel < out_channel_end; ++out_channel) { + UNROLL for (int out_row = out_row_start; out_row <= out_row_end; + ++out_row) { + const int filter_row = in_row + pad_height - out_row * stride; + const int filter_dm = out_channel - out_channel_start; + + const int temp_filter_offset = filter_width * filter_row; + for (int out_col = out_col_start; out_col <= out_col_end; ++out_col) { + const int filter_col = in_col + pad_width - out_col * stride; const int filter_offset = - filter_dm + args.depth_multiplier * - (in_d + in_depth * (f_c + temp_filter_offset)); + filter_dm + + args.depth_multiplier * + (in_channel + in_depth * (filter_col + temp_filter_offset)); const int out_backprop_offset = - (b * out_depth * out_rows * out_cols) + - (out_d * out_rows * out_cols) + (out_r * out_cols) + (out_c); + (batch * out_depth * out_height * out_width) + + (out_channel * out_height * out_width) + (out_row * out_width) + + (out_col); sum += ldg(out_backprop + out_backprop_offset) * ldg(filter + filter_offset); } } } - const int in_backprop_offset = (b * in_rows * in_cols * in_depth) + - (in_d * in_rows * in_cols) + - (in_r * in_cols) + (in_c); + const int in_backprop_offset = (batch * in_height * in_width * in_depth) + + (in_channel * in_height * in_width) + + (in_row * in_width) + (in_col); in_backprop[in_backprop_offset] = sum; } } template -void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, +void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device, const DepthwiseArgs& args, const T* out_backprop, const T* filter, T* in_backprop, TensorFormat data_format) { void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); - if (data_format == FORMAT_NHWC) { - kernel = DepthwiseConv2dBackpropInputGPUKernelNHWC< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; - } else if (data_format == FORMAT_NCHW) { - kernel = DepthwiseConv2dBackpropInputGPUKernelNCHW< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; - } else { - assert(false && "Incorrect data format"); - return; + switch (data_format) { + case FORMAT_NHWC: + kernel = DepthwiseConv2dBackpropInputGPUKernelNHWC< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + break; + case FORMAT_NCHW: + kernel = DepthwiseConv2dBackpropInputGPUKernelNCHW< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + break; + case FORMAT_NCHW_VECT_C: + LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + return; } const int num_in_backprop = args.batch * args.in_rows * args.in_cols * args.in_depth; CudaLaunchConfig config = - GetCudaLaunchConfig(num_in_backprop, d, kernel, 0, 0); - kernel<<>>( + GetCudaLaunchConfig(num_in_backprop, device, kernel, 0, 0); + kernel<<>>( args, out_backprop, filter, in_backprop, num_in_backprop); } template -void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, +void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& device, const DepthwiseArgs& args, const T* out_backprop, const T* filter, T* in_backprop, @@ -889,17 +941,17 @@ void LaunchDepthwiseConv2dBackpropInputGPU(const GpuDevice& d, if (CanLaunchDepthwiseConv2dGPUSmall(args)) { LaunchDepthwiseConv2dGPUSmall( - d, args, out_backprop, filter, in_backprop, data_format); + device, args, out_backprop, filter, in_backprop, data_format); return; } LaunchDepthwiseConv2dBackpropInputGPU( - d, args, out_backprop, filter, in_backprop, data_format); + device, args, out_backprop, filter, in_backprop, data_format); } else { LaunchDepthwiseConv2dBackpropInputGPU( - d, args, out_backprop, filter, in_backprop, data_format); + device, args, out_backprop, filter, in_backprop, data_format); } } @@ -908,13 +960,13 @@ template void LaunchDepthwiseConvBackpropInputOp::operator()( OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, const T* filter, T* in_backprop, TensorFormat data_format) { - const GpuDevice& d = ctx->eigen_device(); + const GpuDevice& device = ctx->eigen_device(); if (args.filter_rows == 3 && args.filter_cols == 3) { LaunchDepthwiseConv2dBackpropInputGPU( - d, args, out_backprop, filter, in_backprop, data_format); + device, args, out_backprop, filter, in_backprop, data_format); } else { LaunchDepthwiseConv2dBackpropInputGPU( - d, args, out_backprop, filter, in_backprop, data_format); + device, args, out_backprop, filter, in_backprop, data_format); } auto stream = ctx->op_device_context()->stream(); OP_REQUIRES(ctx, stream->ok(), @@ -936,75 +988,85 @@ __global__ void __launch_bounds__(640, 2) const T* input, T* filter_backprop, int num_out_backprop) { - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int in_height = args.in_rows; + const int in_width = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; const int depth_multiplier = kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; - const int out_rows = args.out_rows; - const int out_cols = args.out_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_height = args.out_rows; + const int out_width = args.out_cols; const int out_depth = args.out_depth; CUDA_1D_KERNEL_LOOP(thread_id, num_out_backprop) { // Compute the indexes of this thread in the output. - const int out_d = thread_id % out_depth; - const int out_c = (thread_id / out_depth) % out_cols; - const int out_r = (thread_id / out_depth / out_cols) % out_rows; - const int b = thread_id / out_depth / out_cols / out_rows; + const int out_channel = thread_id % out_depth; + const int out_col = (thread_id / out_depth) % out_width; + const int out_row = (thread_id / out_depth / out_width) % out_height; + const int batch = thread_id / out_depth / out_width / out_height; // Compute the input depth and the index of depth multiplier. - const int in_d = out_d / depth_multiplier; - const int dm = out_d % depth_multiplier; + const int in_channel = out_channel / depth_multiplier; + const int dm = out_channel % depth_multiplier; // Decide if all input is valid, if yes, we can skip the boundary checks // for each input. - const int in_r_start = out_r * stride - pad_rows; - const int in_c_start = out_c * stride - pad_cols; - const int in_r_end = in_r_start + filter_rows; - const int in_c_end = in_c_start + filter_cols; + const int in_row_start = out_row * stride - pad_height; + const int in_col_start = out_col * stride - pad_width; + const int in_row_end = in_row_start + filter_height; + const int in_col_end = in_col_start + filter_width; const int out_backprop_offset = - out_d + out_depth * (out_c + out_cols * (out_r + out_rows * b)); + out_channel + + out_depth * (out_col + out_width * (out_row + out_height * batch)); const T out_bp = ldg(out_backprop + out_backprop_offset); - if (in_r_start >= 0 && in_c_start >= 0 && in_r_end < in_rows && - in_c_end < in_cols) { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = in_r_start + f_r; + if (in_row_start >= 0 && in_col_start >= 0 && in_row_end < in_height && + in_col_end < in_width) { + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = in_row_start + filter_row; // Avoid repeated computation. - const int input_offset_temp = in_cols * (in_r + in_rows * b); - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = in_c_start + f_c; + const int input_offset_temp = in_width * (in_row + in_height * batch); + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = in_col_start + filter_col; - const int input_offset = in_d + in_depth * (in_c + input_offset_temp); + const int input_offset = + in_channel + in_depth * (in_col + input_offset_temp); T partial_sum = ldg(input + input_offset) * out_bp; - T* addr = filter_backprop + - (dm + depth_multiplier * - (in_d + in_depth * (f_c + filter_cols * f_r))); + T* addr = + filter_backprop + + (dm + depth_multiplier * + (in_channel + + in_depth * (filter_col + filter_width * filter_row))); CudaAtomicAdd(addr, partial_sum); } } } else { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = in_r_start + f_r; + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = in_row_start + filter_row; // Avoid repeated computation. - const int input_offset_temp = in_cols * (in_r + in_rows * b); - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = in_c_start + f_c; - const int addr_temp = filter_cols * f_r; - - if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols) { + const int input_offset_temp = in_width * (in_row + in_height * batch); + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = in_col_start + filter_col; + const int addr_temp = filter_width * filter_row; + + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { const int input_offset = - in_d + in_depth * (in_c + input_offset_temp); + in_channel + in_depth * (in_col + input_offset_temp); T partial_sum = ldg(input + input_offset) * out_bp; T* addr = filter_backprop + - (dm + depth_multiplier * (in_d + in_depth * (f_c + addr_temp))); + (dm + depth_multiplier * + (in_channel + in_depth * (filter_col + addr_temp))); // Potentially many threads can add to the same address so we have // to use atomic add here. // TODO(jmchen): If atomic add turns out to be slow, we can: @@ -1048,9 +1110,9 @@ __device__ __forceinline__ T WarpSumReduce(T val) { // memory are warp-accumulated (in chunks of kAccumPixels elements) and summed // up in global memory using atomics. // Requirements: threads per block must be multiple of 32 and <= launch_bounds, -// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockSlices. +// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth. template + int kBlockDepth, int kAccumPixels> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const DepthwiseArgs args, const T* output, const T* input, T* filter) { @@ -1059,40 +1121,42 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[]; T* const shared_data = reinterpret_cast(shared_memory); - const int batches = args.batch; - const int in_rows = args.in_rows; - const int in_cols = blockDim.y; // slower (see b/62280718): args.in_cols; + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = blockDim.y; // slower (see b/62280718): args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; - const int block_rows = blockDim.z; + assert(blockDim.x == kBlockDepth); + assert(blockDim.y == args.in_cols); + const int block_height = blockDim.z; // These values are the same for all threads and could // be precomputed on the CPU. - const int block_size = block_rows * in_cols * kBlockSlices; + const int block_size = block_height * in_width * kBlockDepth; assert((block_size & 31) == 0); - const int in_row_size = in_cols * in_depth; - const int in_size = in_rows * in_row_size; - const int in_increment = (in_cols - 1) * kBlockSlices; - const int filter_pixels = filter_rows * filter_cols; - const int tile_cols = in_cols + filter_cols - 1; - const int tile_rows = 2 * block_rows + filter_rows - 1; - const int tile_row_size = tile_cols * kBlockSlices; - const int tile_size = tile_rows * tile_row_size; - const int tile_offset = block_rows * tile_row_size; - const int pad_offset = pad_rows * tile_cols + pad_cols; - const int batch_blocks = (in_depth + kBlockSlices - 1) / kBlockSlices; - const int in_blocks = batch_blocks * batches; - const int tensor_offset = block_rows * in_row_size; + const int in_row_size = in_width * in_depth; + const int in_size = in_height * in_row_size; + const int in_increment = (in_width - 1) * kBlockDepth; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int tile_height = 2 * block_height + filter_height - 1; + const int tile_row_size = tile_width * kBlockDepth; + const int tile_size = tile_height * tile_row_size; + const int tile_offset = block_height * tile_row_size; + const int pad_offset = pad_height * tile_width + pad_width; + const int batch_blocks = (in_depth + kBlockDepth - 1) / kBlockDepth; + const int in_blocks = batch_blocks * num_batches; + const int tensor_offset = block_height * in_row_size; // The accumulator has a fixed number of pixels that can be reduced by one - // warp. Pixels beyond ceil(in_pixels * kBlockSlices / 64) are never written. - assert(kAccumPixels * 64 >= in_rows * in_cols * kBlockSlices); - const int accum_increment = kAccumPixels * kBlockSlices; + // warp. Pixels beyond ceil(in_pixels * kBlockDepth / 64) are never written. + assert(kAccumPixels * 64 >= in_height * in_width * kBlockDepth); + const int accum_increment = kAccumPixels * kBlockDepth; const int accum_size = filter_pixels * accum_increment; const int thread_depth = threadIdx.x; @@ -1100,8 +1164,8 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const int thread_row = threadIdx.z; // Position in block. - const int thread_pix = thread_row * in_cols + thread_col; - const int thread_idx = thread_pix * kBlockSlices + thread_depth; + const int thread_pix = thread_row * in_width + thread_col; + const int thread_idx = thread_pix * kBlockDepth + thread_depth; // Initialize tile, in particular the padding and accumulator. for (int i = thread_idx; i < tile_size + accum_size; i += block_size) { @@ -1113,31 +1177,31 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const int tensor_idx = thread_pix * in_depth + thread_depth; // Position in (padded) shared memory. - const int data_pix = thread_row * tile_cols + thread_col; - const int data_idx = data_pix * kBlockSlices + thread_depth; + const int data_pix = thread_row * tile_width + thread_col; + const int data_idx = data_pix * kBlockDepth + thread_depth; - // Position in shared memory, offset by pad_rows / pad_cols. + // Position in shared memory, offset by pad_height / pad_width. const int tile_pix = data_pix + pad_offset; - const int tile_idx = tile_pix * kBlockSlices + thread_depth; + const int tile_idx = tile_pix * kBlockDepth + thread_depth; - // Position in accumulator (kBlockSlices per warp, depth major). - const int accum_pix = thread_pix / (32 / kBlockSlices); + // Position in accumulator (kBlockDepth per warp, depth major). + const int accum_pix = thread_pix / (32 / kBlockDepth); const int accum_idx = thread_depth * kAccumPixels + accum_pix; - const int max_depth = in_depth - thread_depth; + const int max_channel = in_depth - thread_depth; const int accum_offset = tile_size + accum_idx; - const bool skip_second = block_rows + thread_row >= in_rows; + const bool skip_second = block_height + thread_row >= in_height; for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { const int batch = b / batch_blocks; - const int stack = b - batch * batch_blocks; + const int block = b - batch * batch_blocks; - const int start_depth = stack * kBlockSlices; - const int filter_offset = tensor_idx + start_depth; + const int start_channel = block * kBlockDepth; + const int filter_offset = tensor_idx + start_channel; const int inout_offset = batch * in_size + filter_offset; - const bool depth_in_range = start_depth < max_depth; + const bool channel_in_range = start_channel < max_channel; - if (depth_in_range) { + if (channel_in_range) { const T* const in_ptr = inout_offset + input; T* const tile_ptr = tile_idx + shared_data; tile_ptr[0] = ldg(in_ptr); @@ -1148,26 +1212,26 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( // Note: the condition to reach this is uniform across the entire block. __syncthreads(); - unsigned active_threads = CudaBallotSync(kCudaWarpAll, depth_in_range); + unsigned active_threads = CudaBallotSync(kCudaWarpAll, channel_in_range); - if (depth_in_range) { + if (channel_in_range) { const T* const out_ptr = inout_offset + output; const T out1 = ldg(out_ptr); const T out2 = skip_second ? T(0) : ldg(tensor_offset + out_ptr); int shared_offset = data_idx; T* accum_ptr = accum_offset + shared_data; - UNROLL for (int r = 0; r < filter_rows; ++r) { - UNROLL for (int c = 0; c < filter_cols; ++c) { + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { const T* const tile_ptr = shared_offset + shared_data; T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset]; // Warp-accumulate pixels of the same depth and write to accumulator. - for (int delta = 16; delta >= kBlockSlices; delta /= 2) { + for (int delta = 16; delta >= kBlockDepth; delta /= 2) { val += CudaShuffleXorSync(active_threads, val, delta); } - if (!(thread_idx & 32 - kBlockSlices) /* lane_idx < kBlockSlices */) { + if (!(thread_idx & 32 - kBlockDepth) /* lane_idx < kBlockDepth */) { *accum_ptr = val; } - shared_offset += kBlockSlices; + shared_offset += kBlockDepth; accum_ptr += accum_increment; } shared_offset += in_increment; @@ -1180,10 +1244,10 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall( const T* const accum_data = tile_size + shared_data; for (int i = thread_idx; i < accum_size; i += block_size) { const int filter_idx = i / kAccumPixels; - const int filter_pix = filter_idx / kBlockSlices; - const int filter_depth = filter_idx % kBlockSlices + start_depth; - const int filter_offset = filter_pix * in_depth + filter_depth; - if (filter_depth < in_depth) { + const int filter_pix = filter_idx / kBlockDepth; + const int filter_channel = filter_idx % kBlockDepth + start_channel; + const int filter_offset = filter_pix * in_depth + filter_channel; + if (filter_channel < in_depth) { T val = accum_data[i]; // Warp-accumulate the pixels of the same depth from the accumulator. val = WarpSumReduce(val); @@ -1204,81 +1268,90 @@ __global__ void __launch_bounds__(640, 2) const T* input, T* filter_backprop, int num_out_backprop) { - const int in_rows = args.in_rows; - const int in_cols = args.in_cols; + const int in_height = args.in_rows; + const int in_width = args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; const int depth_multiplier = kKnownDepthMultiplier < 0 ? args.depth_multiplier : kKnownDepthMultiplier; const int stride = args.stride; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; - const int out_rows = args.out_rows; - const int out_cols = args.out_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; + const int out_height = args.out_rows; + const int out_width = args.out_cols; const int out_depth = args.out_depth; CUDA_1D_KERNEL_LOOP(thread_id, num_out_backprop) { // Compute the indexes of this thread in the output. - const int out_c = thread_id % out_cols; - const int out_r = (thread_id / out_cols) % out_rows; - const int out_d = (thread_id / out_cols / out_rows) % out_depth; + const int out_col = thread_id % out_width; + const int out_row = (thread_id / out_width) % out_height; + const int out_channel = (thread_id / out_width / out_height) % out_depth; - const int b = thread_id / out_depth / out_cols / out_rows; + const int batch = thread_id / out_depth / out_width / out_height; // Compute the input depth and the index of depth multiplier. - const int in_d = out_d / depth_multiplier; - const int dm = out_d % depth_multiplier; + const int in_channel = out_channel / depth_multiplier; + const int dm = out_channel % depth_multiplier; // Decide if all input is valid, if yes, we can skip the boundary checks // for each input. - const int in_r_start = out_r * stride - pad_rows; - const int in_c_start = out_c * stride - pad_cols; - const int in_r_end = in_r_start + filter_rows; - const int in_c_end = in_c_start + filter_cols; + const int in_row_start = out_row * stride - pad_height; + const int in_col_start = out_col * stride - pad_width; + const int in_row_end = in_row_start + filter_height; + const int in_col_end = in_col_start + filter_width; - const int out_backprop_offset = (b * out_depth * out_rows * out_cols) + - (out_d * out_rows * out_cols) + - (out_r * out_cols) + (out_c); + const int out_backprop_offset = + (batch * out_depth * out_height * out_width) + + (out_channel * out_height * out_width) + (out_row * out_width) + + (out_col); const T out_bp = ldg(out_backprop + out_backprop_offset); - if (in_r_start >= 0 && in_c_start >= 0 && in_r_end < in_rows && - in_c_end < in_cols) { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = in_r_start + f_r; + if (in_row_start >= 0 && in_col_start >= 0 && in_row_end < in_height && + in_col_end < in_width) { + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = in_row_start + filter_row; // Avoid repeated computation. - const int input_offset_temp = (b * in_depth * in_rows * in_cols) + - (in_d * in_rows * in_cols) + - (in_r * in_cols); - - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = in_c_start + f_c; - const int input_offset = input_offset_temp + in_c; + const int input_offset_temp = + (batch * in_depth * in_height * in_width) + + (in_channel * in_height * in_width) + (in_row * in_width); + + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = in_col_start + filter_col; + const int input_offset = input_offset_temp + in_col; T partial_sum = ldg(input + input_offset) * out_bp; - T* addr = filter_backprop + - (dm + depth_multiplier * - (in_d + in_depth * (f_c + filter_cols * f_r))); + T* addr = + filter_backprop + + (dm + depth_multiplier * + (in_channel + + in_depth * (filter_col + filter_width * filter_row))); CudaAtomicAdd(addr, partial_sum); } } } else { - UNROLL for (int f_r = 0; f_r < filter_rows; ++f_r) { - const int in_r = in_r_start + f_r; + UNROLL for (int filter_row = 0; filter_row < filter_height; + ++filter_row) { + const int in_row = in_row_start + filter_row; // Avoid repeated computation. - const int input_offset_temp = (b * in_depth * in_rows * in_cols) + - (in_d * in_rows * in_cols) + - (in_r * in_cols); - UNROLL for (int f_c = 0; f_c < filter_cols; ++f_c) { - const int in_c = in_c_start + f_c; - const int addr_temp = filter_cols * f_r; - - if (in_r >= 0 && in_r < in_rows && in_c >= 0 && in_c < in_cols) { - const int input_offset = input_offset_temp + in_c; + const int input_offset_temp = + (batch * in_depth * in_height * in_width) + + (in_channel * in_height * in_width) + (in_row * in_width); + UNROLL for (int filter_col = 0; filter_col < filter_width; + ++filter_col) { + const int in_col = in_col_start + filter_col; + const int addr_temp = filter_width * filter_row; + + if (in_row >= 0 && in_row < in_height && in_col >= 0 && + in_col < in_width) { + const int input_offset = input_offset_temp + in_col; T partial_sum = ldg(input + input_offset) * out_bp; T* addr = filter_backprop + - (dm + depth_multiplier * (in_d + in_depth * (f_c + addr_temp))); + (dm + depth_multiplier * + (in_channel + in_depth * (filter_col + addr_temp))); // Potentially many threads can add to the same address so we have // to use atomic add here. // TODO(jmchen): If atomic add turns out to be slow, we can: @@ -1307,9 +1380,9 @@ __global__ void __launch_bounds__(640, 2) // memory are warp-accumulated (in chunks of kAccumPixels elements) and summed // up in global memory using atomics. // Requirements: threads per block must be multiple of 32 and <= launch_bounds, -// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockSlices. +// kAccumPixels * 64 >= args.in_rows * args.in_cols * kBlockDepth. template + int kBlockDepth, int kAccumPixels> __global__ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const DepthwiseArgs args, const T* output, const T* input, T* filter) { @@ -1318,39 +1391,41 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( extern __shared__ __align__(sizeof(T)) unsigned char shared_memory[]; T* const shared_data = reinterpret_cast(shared_memory); - const int batches = args.batch; - const int in_rows = args.in_rows; - const int in_cols = blockDim.x; // slower (see b/62280718): args.in_cols; + const int num_batches = args.batch; + const int in_height = args.in_rows; + const int in_width = blockDim.x; // slower (see b/62280718): args.in_cols; const int in_depth = args.in_depth; - const int filter_rows = + const int filter_height = kKnownFilterHeight < 0 ? args.filter_rows : kKnownFilterHeight; - const int filter_cols = + const int filter_width = kKnownFilterWidth < 0 ? args.filter_cols : kKnownFilterWidth; - const int pad_rows = args.pad_rows; - const int pad_cols = args.pad_cols; + const int pad_height = args.pad_rows; + const int pad_width = args.pad_cols; - const int block_rows = blockDim.y; + assert(blockDim.x == args.in_cols); + assert(blockDim.z == kBlockDepth); + const int block_height = blockDim.y; // These values are the same for all threads and could // be precomputed on the CPU. - const int block_pixels = in_cols * block_rows; - const int block_size = block_pixels * kBlockSlices; + const int block_pixels = in_width * block_height; + const int block_size = block_pixels * kBlockDepth; assert((block_size & 31) == 0); - const int in_pixels = in_cols * in_rows; - const int in_increment = in_cols - 1; - const int filter_pixels = filter_rows * filter_cols; - const int tile_cols = in_cols + filter_cols - 1; - const int tile_rows = 2 * block_rows + filter_rows - 1; - const int tile_pixels = tile_cols * tile_rows; - const int tile_size = tile_pixels * kBlockSlices; - const int tile_offset = block_rows * tile_cols; - const int pad_offset = pad_rows * tile_cols + pad_cols; - const int in_slices = in_depth * batches; - const int in_blocks = (in_slices + kBlockSlices - 1) / kBlockSlices; + const int in_pixels = in_width * in_height; + const int in_increment = in_width - 1; + const int filter_pixels = filter_height * filter_width; + const int tile_width = in_width + filter_width - 1; + const int tile_height = 2 * block_height + filter_height - 1; + const int tile_pixels = tile_width * tile_height; + const int tile_size = tile_pixels * kBlockDepth; + const int tile_offset = block_height * tile_width; + const int pad_offset = pad_height * tile_width + pad_width; + const int in_total_depth = in_depth * num_batches; + const int in_blocks = (in_total_depth + kBlockDepth - 1) / kBlockDepth; // The accumulator has a fixed number of pixels that can be reduced by one - // warp. Pixels beyond ceil(in_pixels * kBlockSlices / 64) are never written. - assert(kAccumPixels * 64 >= in_rows * in_cols * kBlockSlices); - const int accum_increment = kAccumPixels * kBlockSlices; + // warp. Pixels beyond ceil(in_pixels * kBlockDepth / 64) are never written. + assert(kAccumPixels * 64 >= in_height * in_width * kBlockDepth); + const int accum_increment = kAccumPixels * kBlockDepth; const int accum_size = filter_pixels * accum_increment; const int thread_col = threadIdx.x; @@ -1358,7 +1433,7 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const int thread_depth = threadIdx.z; // Position in block. - const int thread_pix = thread_row * in_cols + thread_col; + const int thread_pix = thread_row * in_width + thread_col; const int thread_idx = thread_depth * block_pixels + thread_pix; // Initialize tile, in particular the padding and accumulator. @@ -1371,27 +1446,27 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const int tensor_idx = thread_depth * in_pixels + thread_pix; // Position in (padded) shared memory. - const int data_pix = thread_row * tile_cols + thread_col; + const int data_pix = thread_row * tile_width + thread_col; const int data_idx = thread_depth * tile_pixels + data_pix; - // Position in shared memory, offset by pad_rows / pad_cols. + // Position in shared memory, offset by pad_height / pad_width. const int tile_idx = data_idx + pad_offset; - // Position in accumulator (kBlockSlices per warp, depth major). - const int accum_pix = thread_pix / (32 / kBlockSlices); + // Position in accumulator (kBlockDepth per warp, depth major). + const int accum_pix = thread_pix / (32 / kBlockDepth); const int accum_idx = thread_depth * kAccumPixels + accum_pix; - const int max_slice = in_slices - thread_depth; + const int max_channel = in_total_depth - thread_depth; const int accum_offset = tile_size + accum_idx; - const bool skip_second = block_rows + thread_row >= in_rows; + const bool skip_second = block_height + thread_row >= in_height; for (int b = blockIdx.x; b < in_blocks; b += gridDim.x) { - const int slice = b * kBlockSlices; + const int channel = b * kBlockDepth; - const int inout_offset = slice * in_pixels + tensor_idx; - const bool slice_in_range = slice < max_slice; + const int inout_offset = channel * in_pixels + tensor_idx; + const bool channel_in_range = channel < max_channel; - if (slice_in_range) { + if (channel_in_range) { const T* const in_ptr = inout_offset + input; T* const tile_ptr = tile_idx + shared_data; tile_ptr[0] = ldg(in_ptr); @@ -1402,24 +1477,24 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( // Note: the condition to reach this is uniform across the entire block. __syncthreads(); - unsigned active_threads = CudaBallotSync(kCudaWarpAll, slice_in_range); + unsigned active_threads = CudaBallotSync(kCudaWarpAll, channel_in_range); - if (slice_in_range) { + if (channel_in_range) { const T* const out_ptr = inout_offset + output; const T out1 = ldg(out_ptr); const T out2 = skip_second ? T(0) : ldg(block_pixels + out_ptr); int shared_offset = data_idx; T* accum_ptr = accum_offset + shared_data; - UNROLL for (int r = 0; r < filter_rows; ++r) { - UNROLL for (int c = 0; c < filter_cols; ++c) { + UNROLL for (int r = 0; r < filter_height; ++r) { + UNROLL for (int c = 0; c < filter_width; ++c) { const T* const tile_ptr = shared_offset + shared_data; T val = out1 * tile_ptr[0] + out2 * tile_ptr[tile_offset]; // Warp-accumulate pixels of the same depth and write to accumulator. - for (int delta = 16 / kBlockSlices; delta > 0; delta /= 2) { + for (int delta = 16 / kBlockDepth; delta > 0; delta /= 2) { val += CudaShuffleXorSync(active_threads, val, delta); } - if (!(thread_idx & 32 / kBlockSlices - 1)) { - *accum_ptr = val; // kBlockSlices threads per warp. + if (!(thread_idx & 32 / kBlockDepth - 1)) { + *accum_ptr = val; // kBlockDepth threads per warp. } ++shared_offset; accum_ptr += accum_increment; @@ -1434,10 +1509,11 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( const T* const accum_data = tile_size + shared_data; for (int i = thread_idx; i < accum_size; i += block_size) { const int filter_idx = i / kAccumPixels; - const int filter_pix = filter_idx / kBlockSlices; - const int filter_depth = (slice + filter_idx % kBlockSlices) % in_depth; - const int filter_offset = filter_pix * in_depth + filter_depth; - if (filter_depth < in_depth) { + const int filter_pix = filter_idx / kBlockDepth; + const int filter_channel = + (channel + filter_idx % kBlockDepth) % in_depth; + const int filter_offset = filter_pix * in_depth + filter_channel; + if (filter_channel < in_depth) { T val = accum_data[i]; // Warp-accumulate pixels of the same depth from the accumulator. val = WarpSumReduce(val); @@ -1450,109 +1526,119 @@ __launch_bounds__(1024, 2) void DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall( } template + int kBlockDepth, int kAccumPixels> bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs& args, const int block_rows, + const GpuDevice& device, const DepthwiseArgs& args, const int block_height, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { - const int tile_cols = args.in_cols + args.filter_cols - 1; - const int tile_rows = block_rows * 2 + args.filter_rows - 1; - const int tile_pixels = tile_rows * tile_cols; + const int tile_width = args.in_cols + args.filter_cols - 1; + const int tile_height = block_height * 2 + args.filter_rows - 1; + const int tile_pixels = tile_height * tile_width; const int filter_pixels = args.filter_rows * args.filter_cols; const int shared_memory_size = - kBlockSlices * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(T); - if (shared_memory_size > d.sharedMemPerBlock()) { + kBlockDepth * (tile_pixels + filter_pixels * kAccumPixels) * sizeof(T); + if (shared_memory_size > device.sharedMemPerBlock()) { return false; } dim3 block_dim; + int block_count; void (*kernel)(const DepthwiseArgs, const T*, const T*, T*); - if (data_format == FORMAT_NHWC) { - block_dim = dim3(kBlockSlices, args.in_cols, block_rows); - kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kAccumPixels>; - } else if (data_format == FORMAT_NCHW) { - block_dim = dim3(args.in_cols, block_rows, kBlockSlices); - kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, kAccumPixels>; - } else { - assert(false && "Incorrect data format"); - return false; + switch (data_format) { + case FORMAT_NHWC: + block_dim = dim3(kBlockDepth, args.in_cols, block_height); + block_count = + args.batch * DivUp(args.out_depth, kBlockDepth) * kBlockDepth; + kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWCSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>; + break; + case FORMAT_NCHW: + block_dim = dim3(args.in_cols, block_height, kBlockDepth); + block_count = + DivUp(args.batch * args.out_depth, kBlockDepth) * kBlockDepth; + kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHWSmall< + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, kAccumPixels>; + break; + case FORMAT_NCHW_VECT_C: + LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + return false; } - const int num_out_backprop = - args.batch * args.out_rows * args.out_cols * args.out_depth; - CudaLaunchConfig config = - GetCudaLaunchConfig(num_out_backprop, d, kernel, shared_memory_size, - block_dim.x * block_dim.y * block_dim.z); - kernel<<>>( - args, out_backprop, input, filter_backprop); + const int num_out_backprop = args.out_rows * args.out_cols * block_count; + CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize( + num_out_backprop, device, kernel, shared_memory_size, + block_dim.x * block_dim.y * block_dim.z); + kernel<<>>(args, out_backprop, input, filter_backprop); return true; } template + int kBlockDepth> bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs& args, const int block_rows, + const GpuDevice& device, const DepthwiseArgs& args, const int block_height, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { // Minimize (power of two) kAccumPixels, while satisfying - // kAccumPixels * 32 >= block_rows * in_cols * kBlockSlices. - const int block_pixels = block_rows * args.in_cols * kBlockSlices; + // kAccumPixels * 32 >= block_height * in_width * kBlockDepth. + const int block_pixels = block_height * args.in_cols * kBlockDepth; if (block_pixels > 512) { return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 32>( - d, args, block_rows, out_backprop, input, filter_backprop, data_format); + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 32>( + device, args, block_height, out_backprop, input, filter_backprop, + data_format); } else if (block_pixels > 256) { return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 16>( - d, args, block_rows, out_backprop, input, filter_backprop, data_format); + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 16>( + device, args, block_height, out_backprop, input, filter_backprop, + data_format); } else { return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< - T, kKnownFilterWidth, kKnownFilterHeight, kBlockSlices, 8>( - d, args, block_rows, out_backprop, input, filter_backprop, data_format); + T, kKnownFilterWidth, kKnownFilterHeight, kBlockDepth, 8>( + device, args, block_height, out_backprop, input, filter_backprop, + data_format); } } template bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - const GpuDevice& d, const DepthwiseArgs& args, const T* out_backprop, + const GpuDevice& device, const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { - // Maximize (power of two) kBlockSlices while keeping a block within 1024 + // Maximize (power of two) kBlockDepth while keeping a block within 1024 // threads (2 pixels per thread). - int block_slices = 8; - int block_rows = (args.in_rows + 1) / 2; + int block_depth = 8; + int block_height = (args.in_rows + 1) / 2; int round_mask = 1; - for (; block_slices > 1; block_slices /= 2) { - // args.in_cols * block_rows * kBlockSlices must be multiple of 32. - for (; block_rows * args.in_cols * block_slices & 31; + for (; block_depth > 1; block_depth /= 2) { + // args.in_cols * block_height * kBlockDepth must be multiple of 32. + for (; block_height * args.in_cols * block_depth & 31; round_mask = round_mask * 2 + 1) { - block_rows = block_rows + round_mask & ~round_mask; + block_height = block_height + round_mask & ~round_mask; } - int block_size = block_rows * args.in_cols * block_slices; + int block_size = block_height * args.in_cols * block_depth; if (block_size <= 1024) { break; } } - if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_rows)) { + if (!CanLaunchDepthwiseConv2dBackpropFilterGPUSmall(args, block_height)) { return false; } - switch (block_slices) { + switch (block_depth) { case 8: return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< T, kKnownFilterWidth, kKnownFilterHeight, 8>( - d, args, block_rows, out_backprop, input, filter_backprop, + device, args, block_height, out_backprop, input, filter_backprop, data_format); case 4: return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< T, kKnownFilterWidth, kKnownFilterHeight, 4>( - d, args, block_rows, out_backprop, input, filter_backprop, + device, args, block_height, out_backprop, input, filter_backprop, data_format); case 2: return TryLaunchDepthwiseConv2dBackpropFilterGPUSmall< T, kKnownFilterWidth, kKnownFilterHeight, 2>( - d, args, block_rows, out_backprop, input, filter_backprop, + device, args, block_height, out_backprop, input, filter_backprop, data_format); default: return false; @@ -1561,32 +1647,35 @@ bool TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( template -void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, +void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device, const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { void (*kernel)(const DepthwiseArgs, const T*, const T*, T*, int); - if (data_format == FORMAT_NHWC) { - kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWC< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; - } else if (data_format == FORMAT_NCHW) { - kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHW< - T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; - } else { - assert(false && "Incorrect data format"); - return; + switch (data_format) { + case FORMAT_NHWC: + kernel = DepthwiseConv2dBackpropFilterGPUKernelNHWC< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + break; + case FORMAT_NCHW: + kernel = DepthwiseConv2dBackpropFilterGPUKernelNCHW< + T, kKnownFilterWidth, kKnownFilterHeight, kKnownDepthMultiplier>; + break; + case FORMAT_NCHW_VECT_C: + LOG(ERROR) << "FORMAT_NCHW_VECT_C is not supported"; + return; } const int num_out_backprop = args.batch * args.out_rows * args.out_cols * args.out_depth; CudaLaunchConfig config = - GetCudaLaunchConfig(num_out_backprop, d, kernel, 0, 0); - kernel<<>>( + GetCudaLaunchConfig(num_out_backprop, device, kernel, 0, 0); + kernel<<>>( args, out_backprop, input, filter_backprop, num_out_backprop); } template -void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, +void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& device, const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, @@ -1594,17 +1683,17 @@ void LaunchDepthwiseConv2dBackpropFilterGPU(const GpuDevice& d, if (args.depth_multiplier == 1) { if (TryLaunchDepthwiseConv2dBackpropFilterGPUSmall( - d, args, out_backprop, input, filter_backprop, data_format)) { + device, args, out_backprop, input, filter_backprop, data_format)) { return; } LaunchDepthwiseConv2dBackpropFilterGPU( - d, args, out_backprop, input, filter_backprop, data_format); + device, args, out_backprop, input, filter_backprop, data_format); } else { LaunchDepthwiseConv2dBackpropFilterGPU( - d, args, out_backprop, input, filter_backprop, data_format); + device, args, out_backprop, input, filter_backprop, data_format); } } @@ -1613,7 +1702,7 @@ template void LaunchDepthwiseConvBackpropFilterOp::operator()( OpKernelContext* ctx, const DepthwiseArgs& args, const T* out_backprop, const T* input, T* filter_backprop, TensorFormat data_format) { - const GpuDevice& d = ctx->eigen_device(); + const GpuDevice& device = ctx->eigen_device(); auto stream = ctx->op_device_context()->stream(); // Initialize the results to 0. @@ -1625,10 +1714,10 @@ void LaunchDepthwiseConvBackpropFilterOp::operator()( if (args.filter_rows == 3 && args.filter_cols == 3) { LaunchDepthwiseConv2dBackpropFilterGPU( - d, args, out_backprop, input, filter_backprop, data_format); + device, args, out_backprop, input, filter_backprop, data_format); } else { LaunchDepthwiseConv2dBackpropFilterGPU( - d, args, out_backprop, input, filter_backprop, data_format); + device, args, out_backprop, input, filter_backprop, data_format); } OP_REQUIRES(ctx, stream->ok(), errors::Internal("Launch of gpu kernel for " diff --git a/tensorflow/core/kernels/function_ops.cc b/tensorflow/core/kernels/function_ops.cc index 9d4bc35ba890c251b0800f266e7845e411e7a835..a094ebe5e2d1d78ec8f5514dca7b7ebeec4e6b57 100644 --- a/tensorflow/core/kernels/function_ops.cc +++ b/tensorflow/core/kernels/function_ops.cc @@ -32,7 +32,9 @@ limitations under the License. namespace tensorflow { -static const char* const kGradientOp = "SymbolicGradient"; +static const char* const kArgOp = FunctionLibraryDefinition::kArgOp; +static const char* const kRetOp = FunctionLibraryDefinition::kRetOp; +static const char* const kGradientOp = FunctionLibraryDefinition::kGradientOp; class ArgOp : public OpKernel { public: @@ -89,26 +91,25 @@ class RetvalOp : public OpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp); }; -REGISTER_SYSTEM_KERNEL_BUILDER(Name("_Arg").Device(DEVICE_CPU), ArgOp); -REGISTER_SYSTEM_KERNEL_BUILDER(Name("_Retval").Device(DEVICE_CPU), RetvalOp); +REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp); +REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp); #if TENSORFLOW_USE_SYCL #define REGISTER(type) \ REGISTER_KERNEL_BUILDER( \ - Name("_Arg").Device(DEVICE_SYCL).TypeConstraint("T"), ArgOp); + Name(kArgOp).Device(DEVICE_SYCL).TypeConstraint("T"), ArgOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) -TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg") +TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp) .Device(DEVICE_SYCL) .HostMemory("output") .TypeConstraint("T"), ArgOp); #undef REGISTER -#define REGISTER(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("_Retval").Device(DEVICE_SYCL).TypeConstraint("T"), \ - RetvalOp); +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER( \ + Name(kRetOp).Device(DEVICE_SYCL).TypeConstraint("T"), RetvalOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) -TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval") +TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp) .Device(DEVICE_SYCL) .HostMemory("input") .TypeConstraint("T"), @@ -118,16 +119,16 @@ TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval") #define REGISTER(type) \ REGISTER_KERNEL_BUILDER( \ - Name("_Arg").Device(DEVICE_GPU).TypeConstraint("T"), ArgOp); + Name(kArgOp).Device(DEVICE_GPU).TypeConstraint("T"), ArgOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) -TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Arg") +TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp) .Device(DEVICE_GPU) .HostMemory("output") .TypeConstraint("T"), ArgOp); #undef REGISTER -REGISTER_KERNEL_BUILDER(Name("_Arg") +REGISTER_KERNEL_BUILDER(Name(kArgOp) .Device(DEVICE_GPU) .HostMemory("output") .TypeConstraint("T"), @@ -135,9 +136,9 @@ REGISTER_KERNEL_BUILDER(Name("_Arg") #define REGISTER(type) \ REGISTER_KERNEL_BUILDER( \ - Name("_Retval").Device(DEVICE_GPU).TypeConstraint("T"), RetvalOp); + Name(kRetOp).Device(DEVICE_GPU).TypeConstraint("T"), RetvalOp); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER) -TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name("_Retval") +TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp) .Device(DEVICE_GPU) .HostMemory("input") .TypeConstraint("T"), @@ -287,7 +288,8 @@ REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_SYCL), class RemoteCallOp : public AsyncOpKernel { public: explicit RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + OP_REQUIRES_OK(ctx, + ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_)); } ~RemoteCallOp() override {} diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD index 41af950d7deca5a9ce1e2ca6496ccf40fd72dd87..9a7eca03ce276d26321f01f80ad7f1a0a254e4db 100644 --- a/tensorflow/core/kernels/fuzzing/BUILD +++ b/tensorflow/core/kernels/fuzzing/BUILD @@ -43,6 +43,8 @@ tf_ops_fuzz_target_lib("decode_base64") tf_ops_fuzz_target_lib("encode_jpeg") +tf_ops_fuzz_target_lib("decode_bmp") + tf_ops_fuzz_target_lib("decode_png") tf_ops_fuzz_target_lib("decode_jpeg") diff --git a/tensorflow/core/kernels/fuzzing/decode_bmp_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_bmp_fuzz.cc new file mode 100644 index 0000000000000000000000000000000000000000..01c56ac6f67223108768c3e6922d7f193f93d52a --- /dev/null +++ b/tensorflow/core/kernels/fuzzing/decode_bmp_fuzz.cc @@ -0,0 +1,29 @@ +/* Copyright 2018 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" + +namespace tensorflow { +namespace fuzzing { + +class FuzzDecodeBmp : public FuzzStringInputOp { + SINGLE_INPUT_OP_BUILDER(DT_STRING, DecodeBmp); +}; + +STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeBmp); + +} // end namespace fuzzing +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc index bacf3e77408a12a8a95bf7e7ab8f3a580e675675..6b6a14e9a7383b0a0720782acf69e0896df2444e 100644 --- a/tensorflow/core/kernels/logging_ops.cc +++ b/tensorflow/core/kernels/logging_ops.cc @@ -90,4 +90,23 @@ class PrintOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp); +class TimestampOp : public OpKernel { + public: + explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + TensorShape output_shape; // Default shape is 0 dim, 1 element + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, output_shape, &output_tensor)); + + auto output_scalar = output_tensor->scalar(); + double now_us = static_cast(Env::Default()->NowMicros()); + double now_s = now_us / 1000000; + output_scalar() = now_s; + } +}; + +REGISTER_KERNEL_BUILDER(Name("Timestamp").Device(DEVICE_CPU), TimestampOp); + } // end namespace tensorflow diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc index 9cf669a7efc973a7be4f3139b2180d4e3b07797b..5e6958f364dbbfd6ff6cf112a6cef544202ee955 100644 --- a/tensorflow/core/kernels/logging_ops_test.cc +++ b/tensorflow/core/kernels/logging_ops_test.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/tensor.h" @@ -96,5 +99,27 @@ TEST_F(PrintingGraphTest, FirstNSuccess) { test::ExpectTensorEqual(expected, *GetOutput(0)); } +class TimestampTest : public OpsTestBase { + protected: + Status Init() { + TF_CHECK_OK(NodeDefBuilder("op", "Timestamp").Finalize(node_def())); + return InitOp(); + } +}; + +TEST_F(TimestampTest, WaitAtLeast) { + TF_ASSERT_OK(Init()); + TF_ASSERT_OK(RunOpKernel()); + double ts1 = *((*GetOutput(0)).flat().data()); + + // wait 1 second + std::this_thread::sleep_for(std::chrono::seconds(1)); + + TF_ASSERT_OK(RunOpKernel()); + double ts2 = *((*GetOutput(0)).flat().data()); + + EXPECT_LE(1.0, ts2 - ts1); +} + } // end namespace } // end namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc index d9713075be6e20b77ea681a0e71baa21b7b9eea9..723b445a7568775a13b89c9fbf0e7dc70c4b8b8c 100644 --- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc @@ -29,7 +29,6 @@ limitations under the License. #include #include "mkl_cblas.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -41,9 +40,6 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" -#define MKL_Complex8 tensorflow::complex64 -#define MKL_Complex16 tensorflow::complex128 - namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; @@ -180,16 +176,16 @@ class BatchMatMulMkl : public OpKernel { void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB, const MKL_INT *M_Array, const MKL_INT *N_Array, const MKL_INT *K_Array, - const MKL_Complex8 **A_Array, const MKL_INT *lda_Array, - const MKL_Complex8 **B_Array, const MKL_INT *ldb_Array, - MKL_Complex8 **C_Array, const MKL_INT *ldc_Array, + const complex64 **A_Array, const MKL_INT *lda_Array, + const complex64 **B_Array, const MKL_INT *ldb_Array, + complex64 **C_Array, const MKL_INT *ldc_Array, const MKL_INT group_count, const MKL_INT *group_size) { std::vector TransA_array( group_size[0], TransA ? CblasConjTrans : CblasNoTrans); std::vector TransB_array( group_size[0], TransB ? CblasConjTrans : CblasNoTrans); - std::vector alpha_Array(group_size[0], {1.0f, 0.0f}); - std::vector beta_Array(group_size[0], {0.0f, 0.0f}); + std::vector alpha_Array(group_size[0], {1.0f, 0.0f}); + std::vector beta_Array(group_size[0], {0.0f, 0.0f}); cblas_cgemm_batch( Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array, static_cast(&alpha_Array[0]), @@ -202,18 +198,16 @@ class BatchMatMulMkl : public OpKernel { void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA, const bool TransB, const MKL_INT *M_Array, const MKL_INT *N_Array, const MKL_INT *K_Array, - const MKL_Complex16 **A_Array, - const MKL_INT *lda_Array, - const MKL_Complex16 **B_Array, - const MKL_INT *ldb_Array, MKL_Complex16 **C_Array, - const MKL_INT *ldc_Array, const MKL_INT group_count, - const MKL_INT *group_size) { + const complex128 **A_Array, const MKL_INT *lda_Array, + const complex128 **B_Array, const MKL_INT *ldb_Array, + complex128 **C_Array, const MKL_INT *ldc_Array, + const MKL_INT group_count, const MKL_INT *group_size) { std::vector TransA_array( group_size[0], TransA ? CblasConjTrans : CblasNoTrans); std::vector TransB_array( group_size[0], TransB ? CblasConjTrans : CblasNoTrans); - std::vector alpha_Array(group_size[0], {1.0f, 0.0f}); - std::vector beta_Array(group_size[0], {0.0f, 0.0f}); + std::vector alpha_Array(group_size[0], {1.0f, 0.0f}); + std::vector beta_Array(group_size[0], {0.0f, 0.0f}); cblas_zgemm_batch( Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array, static_cast(&alpha_Array[0]), diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 8313224d7fe3e2d307d3642ced5b277b95c85cdb..eccdece5e3db36d0f144da98d358d6b7d0830499 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -1110,19 +1110,12 @@ class MklFusedBatchNormGradOp : public OpKernel { return; } - if (dnn_shape_src.IsMklTensor()) - depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C); - else - ExtractParams(context); - - memory::format format_m; if (dnn_shape_src.IsMklTensor()) { - if (dnn_shape_src.IsTensorInNCHWFormat()) - format_m = memory::format::nchw; - else - format_m = memory::format::nhwc; + depth_ = dnn_shape_src.DimSize(MklDnnDims::Dim_C); + } else if (dnn_shape_diff_dst.IsMklTensor()) { + depth_ = dnn_shape_diff_dst.DimSize(MklDnnDims::Dim_C); } else { - format_m = TFDataFormatToMklDnnDataFormat(tensor_format_); + ExtractParams(context); } MklDnnData src(&cpu_engine); @@ -1146,20 +1139,20 @@ class MklFusedBatchNormGradOp : public OpKernel { diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_); - // set src and diff_dst primitives + // set src and diff_dst primitives according to input layout memory::desc src_md({}, memory::data_undef, memory::format_undef); memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef); - if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) { - if (dnn_shape_src.IsMklTensor()) { - src_md = dnn_shape_src.GetMklLayout(); - diff_dst_md = src_md; - } else { - diff_dst_md = dnn_shape_diff_dst.GetMklLayout(); - src_md = diff_dst_md; - } + if (dnn_shape_src.IsMklTensor()) { + src_md = dnn_shape_src.GetMklLayout(); } else { - src_md = memory::desc(src_dims, MklDnnType(), format_m); - diff_dst_md = src_md; + src_md = memory::desc(src_dims, MklDnnType(), + TFDataFormatToMklDnnDataFormat(tensor_format_)); + } + if (dnn_shape_diff_dst.IsMklTensor()) { + diff_dst_md = dnn_shape_diff_dst.GetMklLayout(); + } else { + diff_dst_md = memory::desc(diff_dst_dims, MklDnnType(), + TFDataFormatToMklDnnDataFormat(tensor_format_)); } src.SetUsrMem(src_md, &src_tensor); diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); @@ -1211,28 +1204,64 @@ class MklFusedBatchNormGradOp : public OpKernel { // allocate diff_src tensor MklDnnShape dnn_shape_diff_src; TensorShape tf_shape_diff_src; - if (dnn_shape_src.IsMklTensor()) { + + // MKL-DNN's BN primitive not provide API to fetch internal format + // set common_md as OpMem + // src and diff_dst will reorder to common_md + // diff_src will set as common_md + memory::desc common_md({}, memory::data_undef, memory::format_undef); + if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) { + if (dnn_shape_src.IsMklTensor()) { + common_md = dnn_shape_src.GetMklLayout(); + } else { + common_md = dnn_shape_diff_dst.GetMklLayout(); + } + } else { + common_md = memory::desc(src_dims, MklDnnType(), + TFDataFormatToMklDnnDataFormat(tensor_format_)); + } + // if any of src and diff_dst as mkl layout, + // then we set diff_src as mkl layout + if (dnn_shape_src.IsMklTensor() || + dnn_shape_diff_dst.IsMklTensor()) { dnn_shape_diff_src.SetMklTensor(true); - auto diff_src_pd = bnrm_fwd_pd.dst_primitive_desc(); + // set diff_src's mkl layout as common_md + auto diff_src_pd = memory::primitive_desc(common_md, cpu_engine); dnn_shape_diff_src.SetMklLayout(&diff_src_pd); dnn_shape_diff_src.SetElemType(MklDnnType()); - dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), src_dims, - format_m); - dnn_shape_diff_src.SetTfDimOrder(dnn_shape_src.GetDimension(), - tensor_format_); + if (dnn_shape_src.IsMklTensor()) { + dnn_shape_diff_src.SetTfLayout( + dnn_shape_src.GetDimension(), + src_dims, + dnn_shape_src.GetTfDataFormat()); + dnn_shape_diff_src.SetTfDimOrder( + dnn_shape_src.GetDimension(), + tensor_format_); + } else { + dnn_shape_diff_src.SetTfLayout( + dnn_shape_diff_dst.GetDimension(), + src_dims, + dnn_shape_diff_dst.GetTfDataFormat()); + dnn_shape_diff_src.SetTfDimOrder( + dnn_shape_diff_dst.GetDimension(), + tensor_format_); + } tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); } else { dnn_shape_diff_src.SetMklTensor(false); + // both src and diff_dst are TensorFlow layout, + // so it is OK to get TensorFlow shape. tf_shape_diff_src = src_tensor.shape(); } AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor, tf_shape_diff_src, dnn_shape_diff_src); - diff_src.SetUsrMem(src_md, diff_src_tensor); + // set diff_src + diff_src.SetUsrMem(common_md, diff_src_tensor); prop_kind pk = prop_kind::backward; auto bnrm_bwd_desc = batch_normalization_backward::desc( - pk, diff_src.GetUsrMemDesc(), src.GetUsrMemDesc(), epsilon_, + pk, common_md, common_md, epsilon_, /* for inference, specify use_global_stats 1. on fwd prop, use mean and variance provided as inputs @@ -1245,11 +1274,16 @@ class MklFusedBatchNormGradOp : public OpKernel { auto bnrm_bwd_pd = batch_normalization_backward::primitive_desc( bnrm_bwd_desc, cpu_engine, bnrm_fwd_pd); + std::vector net; + src.CheckReorderToOpMem(memory::primitive_desc(common_md, + cpu_engine), &net); + diff_dst.CheckReorderToOpMem(memory::primitive_desc(common_md, + cpu_engine), &net); + auto bnrm_bwd_op = batch_normalization_backward( bnrm_bwd_pd, src.GetOpMem(), mean.GetOpMem(), variance.GetOpMem(), diff_dst.GetOpMem(), weights_m, diff_src.GetOpMem(), diff_weights_m); - std::vector net; net.push_back(bnrm_bwd_op); stream(stream::kind::eager).submit(net).wait(); diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc index 5a8799ae93c1bb3a53f19036c7bb13874a80d7fa..e9a2376b545fcec97e1ced5c592351203abadd69 100644 --- a/tensorflow/core/kernels/mkl_input_conversion_op.cc +++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc @@ -145,8 +145,8 @@ class MklInputConversionOp : public OpKernel { const MklShape* mkl_shape; const Tensor* tf_tensor; MklShape* tf_mkl_shape; - uint mkl_tensor_index; - uint tf_tensor_index; + uint32 mkl_tensor_index; + uint32 tf_tensor_index; if (input_shape_0.IsMklTensor() && !input_shape_1.IsMklTensor()) { mkl_tensor = &input_tensor_0; mkl_shape = &input_shape_0; diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 47598f443f76f17a6c0b4005327a4e7d00a6beba..dfa6cecc9bdc231ebf35e587183b5f84b17489e0 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -170,32 +170,32 @@ class MklMatMulOp : public OpKernel { // Matrix-Matrix Multiplication with Complex64 (std::complex) tensors. // For detailed info about parameters, look at FP32 function description. void MklBlasGemm(bool transa, bool transb, const int m, const int n, - const int k, const std::complex* a, const int lda, - const std::complex* b, const int ldb, - std::complex* c, int const ldc) { + const int k, const complex64* a, const int lda, + const complex64* b, const int ldb, complex64* c, + int const ldc) { const MKL_Complex8 alpha = {1.0f, 0.0f}; const MKL_Complex8 beta = {0.0f, 0.0f}; cblas_cgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, - static_cast(&alpha), static_cast(a), - lda, static_cast(b), ldb, - static_cast(&beta), static_cast(c), ldc); + transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha, + reinterpret_cast(a), lda, + reinterpret_cast(b), ldb, &beta, + reinterpret_cast(c), ldc); } // Matrix-Matrix Multiplication with Complex128 (std::complex) // tensors. For detailed info about parameters, look at FP32 function // description. void MklBlasGemm(bool transa, bool transb, const int m, const int n, - const int k, const std::complex* a, const int lda, - const std::complex* b, const int ldb, - std::complex* c, const int ldc) { + const int k, const complex128* a, const int lda, + const complex128* b, const int ldb, complex128* c, + const int ldc) { const MKL_Complex16 alpha = {1.0, 0.0}; const MKL_Complex16 beta = {0.0, 0.0}; cblas_zgemm(CblasRowMajor, transa ? CblasTrans : CblasNoTrans, - transb ? CblasTrans : CblasNoTrans, m, n, k, - static_cast(&alpha), static_cast(a), - lda, static_cast(b), ldb, - static_cast(&beta), static_cast(c), ldc); + transb ? CblasTrans : CblasNoTrans, m, n, k, &alpha, + reinterpret_cast(a), lda, + reinterpret_cast(b), ldb, &beta, + reinterpret_cast(c), ldc); } }; diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 51db3991e2a24f087771f571cd91fc9fbb26040b..6c873af56654c3c4fc1bbc56153cfaf435e54ca2 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -368,8 +368,11 @@ void MklReluGradOp::Compute(OpKernelContext* context) { mkl_context.MklCleanup(); } + + #else // INTEL_MKL_ML + template class MklReluOpBase : public OpKernel { public: @@ -579,17 +582,26 @@ class MklReluGradOpBase : public OpKernel { // allocate diff_src tensor MklDnnShape dnn_shape_diff_src; TensorShape tf_shape_diff_src; - if (dnn_shape_src.IsMklTensor()) { + if (dnn_shape_src.IsMklTensor() || + dnn_shape_diff_dst.IsMklTensor()) { dnn_shape_diff_src.SetMklTensor(true); auto diff_src_pd = relu_bwd_pd.diff_src_primitive_desc(); dnn_shape_diff_src.SetMklLayout(&diff_src_pd); dnn_shape_diff_src.SetElemType(MklDnnType()); - dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), - dnn_shape_src.GetSizesAsMklDnnDims(), - dnn_shape_src.GetTfDataFormat()); + if (dnn_shape_src.IsMklTensor()) { + dnn_shape_diff_src.SetTfLayout(dnn_shape_src.GetDimension(), + dnn_shape_src.GetSizesAsMklDnnDims(), + dnn_shape_src.GetTfDataFormat()); + } else { + dnn_shape_diff_src.SetTfLayout(dnn_shape_diff_dst.GetDimension(), + dnn_shape_diff_dst.GetSizesAsMklDnnDims(), + dnn_shape_diff_dst.GetTfDataFormat()); + } tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); } else { dnn_shape_diff_src.SetMklTensor(false); + // both src and diff_dst are TensorFlow layout, + // so it is ok to get TensorFlow shape. tf_shape_diff_src = src_tensor.shape(); } AllocateOutputSetMklShape(context, diff_src_index, &diff_src_tensor, diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h index 5fafa14b5dbf49d0c9902af4e38653b48d1f179b..ddea9e281b2fbcf0e061fd2bf2758984833a3727 100644 --- a/tensorflow/core/kernels/mkl_tfconv_op.h +++ b/tensorflow/core/kernels/mkl_tfconv_op.h @@ -128,7 +128,7 @@ class MklToTfOp : public OpKernel { #else static void ConvertMklToTf(OpKernel* op_kernel, OpKernelContext* context, string data_format_str, DataType op_data_type, - bool has_avx512f, uint input_number) { + bool has_avx512f, uint32 input_number) { // Check that input tensor is in MKL format. const Tensor& input_tensor = MklGetInput(context, input_number); MklShape input_shape; diff --git a/tensorflow/core/kernels/mkl_transpose_op.cc b/tensorflow/core/kernels/mkl_transpose_op.cc index 764d4c9400e5751de29b9651eebc1328fdd09d59..3f07b317c4d915fd7d304dbbab966837da64757a 100644 --- a/tensorflow/core/kernels/mkl_transpose_op.cc +++ b/tensorflow/core/kernels/mkl_transpose_op.cc @@ -18,9 +18,6 @@ limitations under the License. #ifdef INTEL_MKL #define EIGEN_USE_THREADS -#include "tensorflow/core/framework/numeric_types.h" -#define MKL_Complex8 tensorflow::complex64 -#define MKL_Complex16 tensorflow::complex128 #include "mkl_trans.h" #include "tensorflow/core/kernels/transpose_functor.h" #include "tensorflow/core/kernels/transpose_op.h" @@ -62,10 +59,37 @@ Status MKLTranspose2D(const char trans, const Tensor& in, Tensor* out); INSTANTIATE(float, s) INSTANTIATE(double, d) -INSTANTIATE(complex64, c) -INSTANTIATE(complex128, z) + #undef INSTANTIATE +template <> +Status MKLTranspose2D(const char trans, const Tensor& in, + Tensor* out) { + const MKL_Complex8 alpha = {1.0f, 0.0f}; + mkl_comatcopy( + 'R', trans, in.dim_size(0), in.dim_size(1), alpha, + reinterpret_cast(in.flat().data()), + in.dim_size(1), + reinterpret_cast( + const_cast(out->flat().data())), + in.dim_size(0)); + return Status::OK(); +} + +template <> +Status MKLTranspose2D(const char trans, const Tensor& in, + Tensor* out) { + const MKL_Complex16 alpha = {1.0, 0.0}; + mkl_zomatcopy( + 'R', trans, in.dim_size(0), in.dim_size(1), alpha, + reinterpret_cast(in.flat().data()), + in.dim_size(1), + reinterpret_cast( + const_cast(out->flat().data())), + in.dim_size(0)); + return Status::OK(); +} + static const char kMKLTranspose = 'T'; static const char kMKLConjugateTranspose = 'C'; diff --git a/tensorflow/core/kernels/mutex_ops.cc b/tensorflow/core/kernels/mutex_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..b8b1fc7679d758f2855af33618620e00f1bbb7e1 --- /dev/null +++ b/tensorflow/core/kernels/mutex_ops.cc @@ -0,0 +1,249 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/variant.h" +#include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { + +class Mutex : public ResourceBase { + public: + explicit Mutex(OpKernelContext* c, const string& name) + : locked_(false), + thread_pool_(new thread::ThreadPool( + c->env(), ThreadOptions(), + strings::StrCat("mutex_lock_thread_", SanitizeThreadSuffix(name)), + 1 /* num_threads */, false /* low_latency_hint */)), + name_(name) { + VLOG(2) << "Creating mutex with name " << name << ": " << this; + } + + string DebugString() override { return strings::StrCat("Mutex ", name_); } + + class LockReleaser { + public: + explicit LockReleaser(Mutex* mutex) : mutex_(mutex) {} + + LockReleaser(const LockReleaser&) = delete; + LockReleaser& operator=(const LockReleaser&) = delete; + + virtual ~LockReleaser() { + VLOG(3) << "Destroying LockReleaser " << this << " for mutex: " << mutex_; + if (mutex_) { + mutex_lock lock(mutex_->mu_); + mutex_->locked_ = false; + mutex_->cv_.notify_all(); + VLOG(3) << "Destroying LockReleaser " << this + << ": sent notifications."; + } + } + + private: + Mutex* mutex_; + }; + + struct SharedLockReleaser { + std::shared_ptr shared_lock; + + explicit SharedLockReleaser(std::shared_ptr&& lock) + : shared_lock(std::forward(lock)) { + VLOG(3) << "Creating shared_ptr of " << shared_lock.get() + << " count is: " << shared_lock.use_count(); + } + + SharedLockReleaser(SharedLockReleaser&& rhs) + : shared_lock(std::move(rhs.shared_lock)) { + VLOG(3) << "Moving SharedLockReleaser of " << shared_lock.get() + << " count is: " << shared_lock.use_count(); + } + + SharedLockReleaser(const SharedLockReleaser& rhs) + : shared_lock(rhs.shared_lock) { + VLOG(3) << "Copying SharedLockReleaser of " << shared_lock.get() + << " count is: " << shared_lock.use_count(); + } + + ~SharedLockReleaser() { + VLOG(3) << "Destroying SharedLockReleaser of " << shared_lock.get() + << " count is: " << shared_lock.use_count(); + } + + void Encode(VariantTensorData*) const { + // Not supported. + } + + bool Decode(const VariantTensorData&) { + return false; // Not supported. + } + }; + + void AcquireAsync( + OpKernelContext* c, + std::function fn) { + CancellationManager* cm = c->cancellation_manager(); + CancellationToken token{}; + bool* cancelled = nullptr; + if (cm) { + cancelled = new bool(false); // GUARDED_BY(mu_); + token = cm->get_cancellation_token(); + const bool already_cancelled = + !cm->RegisterCallback(token, [this, cancelled]() { + mutex_lock lock(mu_); + *cancelled = true; + cv_.notify_all(); + }); + if (already_cancelled) { + delete cancelled; + fn(errors::Cancelled("Lock acquisition cancelled."), + SharedLockReleaser{nullptr}); + return; + } + } + thread_pool_->Schedule(std::bind( + [this, c, cm, cancelled, + token](std::function + fn_) { + bool local_locked; + { + mutex_lock lock(mu_); + while (locked_ && !(cancelled && *cancelled)) { + cv_.wait(lock); + } + local_locked = locked_ = !(cancelled && *cancelled); + } + if (cm) { + cm->DeregisterCallback(token); + delete cancelled; + } + if (local_locked) { // Not cancelled. + fn_(Status::OK(), + SharedLockReleaser{std::make_shared(this)}); + } else { + fn_(errors::Cancelled("Lock acqusition cancelled."), + SharedLockReleaser{nullptr}); + } + }, + std::move(fn))); + } + + private: + mutex mu_; + condition_variable cv_ GUARDED_BY(mu_); + bool locked_ GUARDED_BY(mu_); + std::unique_ptr thread_pool_; + string name_; +}; + +} // namespace + +class MutexLockOp : public AsyncOpKernel { + public: + explicit MutexLockOp(OpKernelConstruction* c) : AsyncOpKernel(c) {} + + public: + void ComputeAsync(OpKernelContext* c, DoneCallback done) override { + Mutex* mutex = nullptr; + OP_REQUIRES_OK_ASYNC( + c, + LookupOrCreateResource(c, HandleFromInput(c, 0), &mutex, + [this, c](Mutex** ptr) { + *ptr = new Mutex( + c, HandleFromInput(c, 0).name()); + return Status::OK(); + }), + done); + + Tensor* variant; + OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, TensorShape({}), &variant), + done); + + mutex->AcquireAsync( + c, std::bind( + [this, c, variant, mutex](DoneCallback done_, + // End of bound arguments. + const Status& s, + Mutex::SharedLockReleaser&& lock) { + core::ScopedUnref unref(mutex); + VLOG(2) << "Finished locking mutex " << mutex + << " with lock: " << lock.shared_lock.get() + << " status: " << s.ToString(); + if (s.ok()) { + variant->scalar()() = std::move(lock); + } else { + c->SetStatus(s); + } + done_(); + }, + std::move(done), std::placeholders::_1, std::placeholders::_2)); + } +}; + +class ConsumeMutexLockOp : public OpKernel { + public: + explicit ConsumeMutexLockOp(OpKernelConstruction* context) + : OpKernel(context) {} + + void Compute(OpKernelContext* c) override { + VLOG(2) << "Executing ConsumeMutexLockOp"; + const Tensor& lock_t = c->input(0); + OP_REQUIRES( + c, lock_t.dims() == 0, + errors::InvalidArgument("Expected input to be a scalar, saw shape: ", + lock_t.shape().DebugString())); + OP_REQUIRES( + c, lock_t.dtype() == DT_VARIANT, + errors::InvalidArgument("Expected input to be a variant, saw type: ", + DataTypeString(lock_t.dtype()))); + const auto* lock = + lock_t.scalar()().get(); + OP_REQUIRES(c, lock, + errors::InvalidArgument( + "Expected input to contain a SharedLockReleaser " + "object, but saw variant: '", + lock_t.scalar()().DebugString(), "'")); + const int use_count = lock->shared_lock.use_count(); + OP_REQUIRES( + c, use_count == 1, + errors::InvalidArgument("Expected use count of lock to be 1, but saw: ", + use_count)); + } + + bool IsExpensive() override { return false; } +}; + +REGISTER_KERNEL_BUILDER(Name("MutexLock").Device(DEVICE_CPU), MutexLockOp); + +REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_CPU), + ResourceHandleOp); + +REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock").Device(DEVICE_CPU), + ConsumeMutexLockOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index 5d28b87e6bb8c0f51653fc005a2f62734a44d321..903b898d0ac850e88c216cb1cc266cdb29fb4ca7 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -105,7 +105,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes, } const int output_size = std::min(max_output_size.scalar()(), num_boxes); - typename TTypes::ConstTensor boxes_data = boxes.tensor(); + TTypes::ConstTensor boxes_data = boxes.tensor(); std::vector scores_data(num_boxes); std::copy_n(scores.flat().data(), num_boxes, scores_data.begin()); @@ -138,8 +138,7 @@ void DoNonMaxSuppressionOp(OpKernelContext* context, const Tensor& boxes, Tensor* output = nullptr; TensorShape output_shape({static_cast(selected.size())}); OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); - typename TTypes::Tensor selected_indices_data = - output->tensor(); + TTypes::Tensor selected_indices_data = output->tensor(); std::copy_n(selected.begin(), selected.size(), selected_indices_data.data()); } diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc index ddfeb1bb7903e4dd66f557df7702c083a6b62899..661d47d925d1143d88b88d73b4ca51c654b43498 100644 --- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc +++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc @@ -29,7 +29,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/util/cuda_kernel_helper.h" -#ifdef COMPILER_MSVC +#if defined(_MSC_VER) && !defined(__clang__) // msvc does not support unroll. One could try the loop pragma but we need to // take a closer look if this generates better code in this case. For now let // the compiler take care of it. diff --git a/tensorflow/core/kernels/quantized_resize_bilinear_op.cc b/tensorflow/core/kernels/quantized_resize_bilinear_op.cc index fb2faede2f9f9e56728ad3ab354440eabd488818..9a1dcd0d496e45977704f49c10fba1048effc943 100644 --- a/tensorflow/core/kernels/quantized_resize_bilinear_op.cc +++ b/tensorflow/core/kernels/quantized_resize_bilinear_op.cc @@ -697,8 +697,8 @@ class QuantizedResizeBilinearOp : public OpKernel { // Return if the output is empty. if (st.output->NumElements() == 0) return; - typename TTypes::ConstTensor image_data = input.tensor(); - typename TTypes::Tensor output_data = st.output->tensor(); + typename TTypes::ConstTensor image_data(input.tensor()); + typename TTypes::Tensor output_data(st.output->tensor()); ResizeBilinear(image_data, st.height_scale, st.width_scale, in_min, in_max, &output_data); diff --git a/tensorflow/core/kernels/random_crop_op.cc b/tensorflow/core/kernels/random_crop_op.cc index 554909760aa8a6bebe7e2988cd995f9373e1cc33..b89bda4769dd42590006f803ea45dbb7573bc332 100644 --- a/tensorflow/core/kernels/random_crop_op.cc +++ b/tensorflow/core/kernels/random_crop_op.cc @@ -92,8 +92,8 @@ class RandomCropOp : public OpKernel { // TODO(shlens): Do this more efficiently with memcpy once padding is // available for smaller images. - typename TTypes::ConstTensor input_data = input.tensor(); - typename TTypes::Tensor output_data = output->tensor(); + typename TTypes::ConstTensor input_data(input.tensor()); + typename TTypes::Tensor output_data(output->tensor()); for (int y = 0; y < target_height; ++y) { for (int x = 0; x < target_width; ++x) { diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index 15ae4c1fc53b2b9bfe1d6085d2ecbc3659705b47..9237fa51d885c633675146191dc384dd87d8ab22 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -280,8 +280,8 @@ __global__ void ColumnReduceMax16ColumnsKernel( const int rows_in_this_warp = min(rows_per_warp, num_rows - start_row_warp); // not the most efficient way to do this sum for (int i = 1; i < rows_in_this_warp; ++i) { - value_type tmp = - cub::ShuffleIndex(sum, threadIdx.x + i * num_cols, 32, 0xffffffff); + value_type tmp = cub::ShuffleIndex<32, value_type>( + sum, static_cast(threadIdx.x + i * num_cols), 0xffffffff); if (lane < num_cols) sum = op(sum, tmp); } diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc index ada50dfb70de447d9be9f735c6b973a25933cfa5..98b8a0df282a21f6711cc8926762f7bbb4ef52b0 100644 --- a/tensorflow/core/kernels/resize_area_op.cc +++ b/tensorflow/core/kernels/resize_area_op.cc @@ -149,7 +149,7 @@ class ResizeAreaOp : public OpKernel { if (!context->status().ok()) return; - typename TTypes::ConstTensor input_data = input.tensor(); + typename TTypes::ConstTensor input_data(input.tensor()); // Precompute values used when iterating over x coordinates within a row. // Note that it may be useful to cache x_interps for a given @@ -190,8 +190,7 @@ class ResizeAreaOp : public OpKernel { void ComputeLoop(const ImageResizerState& st, const std::vector& x_interps, typename TTypes::ConstTensor input_data) { - typename TTypes::Tensor output_data = - st.output->tensor(); + TTypes::Tensor output_data = st.output->tensor(); // When using this algorithm for downsizing, the target pixel value is the // weighted average of all the source pixels. The weight is determined by diff --git a/tensorflow/core/kernels/resize_bicubic_op.cc b/tensorflow/core/kernels/resize_bicubic_op.cc index 86e61bbcefc1ad2b103552101c17a05c3c3ede6e..65014b6c44eb2e5b0adb528c3ce08f01c21e4f26 100644 --- a/tensorflow/core/kernels/resize_bicubic_op.cc +++ b/tensorflow/core/kernels/resize_bicubic_op.cc @@ -480,9 +480,8 @@ class ResizeBicubicOp : public OpKernel { if (!context->status().ok()) return; - typename TTypes::ConstTensor input_data = input.tensor(); - typename TTypes::Tensor output_data = - st.output->tensor(); + typename TTypes::ConstTensor input_data(input.tensor()); + TTypes::Tensor output_data = st.output->tensor(); interpolate_with_caching(input_data, st, output_data); } @@ -510,9 +509,8 @@ class ResizeBicubicOpGrad : public OpKernel { if (!context->status().ok()) return; - typename TTypes::ConstTensor input_grad = - input.tensor(); - typename TTypes::Tensor output_grad = st.output->tensor(); + TTypes::ConstTensor input_grad = input.tensor(); + typename TTypes::Tensor output_grad(st.output->tensor()); ResizeBicubicGrad(input_grad, st, output_grad); } diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc index d9cb993a4b296d053ec5f9f8a44955728dc5c949..dde59e8e741aca2c715aeb9d548979200af8789b 100644 --- a/tensorflow/core/kernels/resize_bilinear_op.cc +++ b/tensorflow/core/kernels/resize_bilinear_op.cc @@ -51,9 +51,8 @@ class ResizeBilinearOp : public OpKernel { // Return if the output is empty. if (st.output->NumElements() == 0) return; - typename TTypes::ConstTensor image_data = input.tensor(); - typename TTypes::Tensor output_data = - st.output->tensor(); + typename TTypes::ConstTensor image_data(input.tensor()); + TTypes::Tensor output_data = st.output->tensor(); functor::ResizeBilinear()(context->eigen_device(), image_data, st.height_scale, @@ -258,9 +257,8 @@ class ResizeBilinearOpGrad : public OpKernel { if (!context->status().ok()) return; - typename TTypes::ConstTensor input_grad = - input.tensor(); - typename TTypes::Tensor output_grad = st.output->tensor(); + TTypes::ConstTensor input_grad = input.tensor(); + typename TTypes::Tensor output_grad(st.output->tensor()); functor::ResizeBilinearGrad()(context->eigen_device(), input_grad, st.height_scale, diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc index bfd29b7ec89e6a2d0e2757db31b707be70d12c1d..8ec526c2b25dc870e150d2afbfb9af6fbd1d778d 100644 --- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc +++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc @@ -56,8 +56,8 @@ class ResizeNearestNeighborOp : public OpKernel { // Return if the output is empty. if (st.output->NumElements() == 0) return; - typename TTypes::ConstTensor input_data = input.tensor(); - typename TTypes::Tensor output_data = st.output->tensor(); + typename TTypes::ConstTensor input_data(input.tensor()); + typename TTypes::Tensor output_data(st.output->tensor()); bool status; if (align_corners_) { @@ -162,8 +162,8 @@ class ResizeNearestNeighborOpGrad : public OpKernel { // Return if the output is empty. if (output->NumElements() == 0) return; - typename TTypes::ConstTensor input_data = input.tensor(); - typename TTypes::Tensor output_data = output->tensor(); + typename TTypes::ConstTensor input_data(input.tensor()); + typename TTypes::Tensor output_data(output->tensor()); const float height_scale = CalculateResizeScale(out_height, in_height, align_corners_); diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 5b4aad3cdd83905716df0fd67cec4817e04a1ee1..2041fb90946860c5164da3cb448ff81d9f654e54 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -130,6 +130,7 @@ REGISTER_KERNEL_BUILDER( ResourceHandleOp) TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); TF_CALL_variant(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA @@ -252,6 +253,7 @@ class AssignVariableOp : public OpKernel { std::unique_ptr input_alias = context->forward_input(1, dtype_, value.shape(), DEVICE_MEMORY, attr); mutex_lock ml(*variable->mu()); + variable->is_initialized = true; if (input_alias) { *variable->tensor() = *input_alias; return; @@ -362,7 +364,7 @@ class AssignVariableOp : public OpKernel { DataTypeString(DT_VARIANT))); mutex_lock ml(*variable->mu()); - + variable->is_initialized = true; *variable->tensor() = Tensor(DT_VARIANT, value.shape()); const auto elements_in = value.flat(); auto elements_out = variable->tensor()->flat(); @@ -398,6 +400,7 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); AssignVariableOp); TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); TF_CALL_variant(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA @@ -456,11 +459,33 @@ TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); AssignUpdateVariableOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA +class VarIsInitializedOp : public OpKernel { + public: + explicit VarIsInitializedOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* context) override { + Tensor* output = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), &output)); + auto output_tensor = output->tensor(); + Var* variable = nullptr; + Status s = LookupResource(context, HandleFromInput(context, 0), &variable); + if (!s.ok()) { + output_tensor() = false; + return; + } + core::ScopedUnref su(variable); + mutex_lock ml(*variable->mu()); + output_tensor() = variable->is_initialized; + } +}; + REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp").Device(DEVICE_CPU), - IsResourceInitialized); + VarIsInitializedOp); #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER(Name("VarIsInitializedOp") diff --git a/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc b/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc index 44a817a5c76d31aa8bde25a5f608b75b81116355..c0fde8042e816c325475a36129fb71630f0ca7c6 100644 --- a/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc +++ b/tensorflow/core/kernels/sample_distorted_bounding_box_op.cc @@ -387,9 +387,9 @@ class SampleDistortedBoundingBoxV2Op : public OpKernel { OP_REQUIRES_OK( context, context->allocate_output(2, TensorShape({1, 1, 4}), &bboxes)); - typename TTypes::Tensor begin_data = begin->tensor(); - typename TTypes::Tensor size_data = size->tensor(); - typename TTypes::Tensor bboxes_data = bboxes->tensor(); + typename TTypes::Tensor begin_data(begin->tensor()); + typename TTypes::Tensor size_data(size->tensor()); + TTypes::Tensor bboxes_data = bboxes->tensor(); begin_data(0) = T(offset_height); size_data(0) = T(target_height); diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h index 51814273b305bfa35bca0ddce0376658064ea56a..fe0a2782f952386e673127776c8f20da3ab1e2d5 100644 --- a/tensorflow/core/kernels/segment_reduction_ops.h +++ b/tensorflow/core/kernels/segment_reduction_ops.h @@ -16,6 +16,14 @@ limitations under the License. #ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ #define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_SEGMENT_REDUCTION_OPS_H_ + +// This file requires the following include because it uses CudaAtomicMax: +// #include "tensorflow/core/util/cuda_kernel_helper.h" + +// Unfortunately we can't add the #include, since it breaks compilation for +// non-GPU targets. This only breaks in clang, because it's more strict for +// template code and CudaAtomicMax is used in template context. + #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" diff --git a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc index ba979e6bb216b649ff4fc3cefa7099ac9cbc1b91..3511c85f7174f8dab47ca3ba05f01d7c4f5110b8 100644 --- a/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc +++ b/tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc @@ -17,10 +17,13 @@ limitations under the License. #define EIGEN_USE_GPU +// We need to include cuda_kernel_helper.h before segment_reduction_ops.h +// See comment in segment_reduction_ops.h for more details. +#include "tensorflow/core/util/cuda_kernel_helper.h" + #include "tensorflow/core/kernels/segment_reduction_ops.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/util/cuda_device_functions.h" -#include "tensorflow/core/util/cuda_kernel_helper.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc index 799c574d1542c345c606c276b0cc24fe61a47bba..64e0a68c2c119394561e947c4cf37838defd2d39 100644 --- a/tensorflow/core/kernels/serialize_sparse_op.cc +++ b/tensorflow/core/kernels/serialize_sparse_op.cc @@ -44,6 +44,8 @@ class SerializeSparseOp : public OpKernel { explicit SerializeSparseOp(OpKernelConstruction* context) : OpKernel(context) {} + bool IsExpensive() override; + Status Initialize(Tensor* result); Status Serialize(const Tensor& input, T* result); @@ -82,6 +84,21 @@ class SerializeSparseOp : public OpKernel { } }; +// NOTE(mrry): We specialize the IsExpensive() method differently for +// the string and variant cases, because (i) the string version +// actually performs memory copies as part of its serialization (and +// is hence potentially expensive), and (ii) the variant version +// performs O(1) shallow copies (and hence is much cheaper than +// dispatching to another thread would be). +template <> +bool SerializeSparseOp::IsExpensive() { + return true; +} +template <> +bool SerializeSparseOp::IsExpensive() { + return false; +} + template <> Status SerializeSparseOp::Initialize(Tensor* result) { *result = Tensor(DT_STRING, TensorShape({3})); diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index 79369fd4a9cc1668bc12cfdb466ad2ec2bbe8d11..77594479cb1252d311fbfea8572590b0b32faecd 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -358,11 +358,11 @@ class MklSliceOp : public OpKernel { /* data format = NCHW */ #pragma omp parallel for - for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { + for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { T* ip = in_buf + (d0 * in_strides[0]); T* op = op_buf + ((d0 - begin[0]) * out_strides[0]); #pragma omp parallel for - for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { + for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { T* ip1 = ip + (d1 * in_strides[1]); T* op1 = op + ((d1 - begin[1]) * out_strides[1]); // For NCHW, H and W will be contiguous. So we can copy @@ -376,15 +376,15 @@ class MklSliceOp : public OpKernel { /* data_format = NHWC */ #pragma omp parallel for - for (size_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { + for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { T* ip = in_buf + (d0 * in_strides[0]); T* op = op_buf + ((d0 - begin[0]) * out_strides[0]); #pragma omp parallel for - for (size_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { + for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { T* ip1 = ip + (d1 * in_strides[1]); T* op1 = op + ((d1 - begin[1]) * out_strides[1]); #pragma omp parallel for - for (size_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) { + for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) { T* ip2 = ip1 + (d2 * in_strides[2]); T* ip3 = ip2 + begin[3]; T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]); diff --git a/tensorflow/core/kernels/snapshot_op.h b/tensorflow/core/kernels/snapshot_op.h index 2c79893b49661519515a7b4a537ff3caeceba2be..b94834f15988a21ad41eefc8030b3da1a58875f8 100644 --- a/tensorflow/core/kernels/snapshot_op.h +++ b/tensorflow/core/kernels/snapshot_op.h @@ -35,12 +35,17 @@ class SnapshotOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); Tensor* output = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output(0, input.shape(), &output)); - const Device& device = context->eigen_device(); - device.memcpy(output->template flat().data(), - input.template flat().data(), - input.NumElements() * sizeof(Scalar)); + // Try to use buffer forwarding to avoid an explicit copy. + OP_REQUIRES_OK(context, context->forward_input_or_allocate_output( + {0}, 0, input.shape(), &output)); + if (!output->SharesBufferWith(input)) { + // We had to allocate a new buffer since the refcount on the input was + // greater than 1. Copy the input to the new buffer. + const Device& device = context->eigen_device(); + device.memcpy(output->template flat().data(), + input.template flat().data(), + input.NumElements() * sizeof(Scalar)); + } } }; diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 7745effe2abe94ba73a2f0d761210e07c62e499c..1e3e92a68a05123bafad77348e6811a14c303301 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -109,17 +109,27 @@ class StridedSliceOp : public OpKernel { if (is_identity) { VLOG(1) << "Strided slice identity "; Tensor tmp; - CHECK(tmp.CopyFrom(input, final_shape)); + OP_REQUIRES(context, tmp.CopyFrom(input, final_shape), + errors::Internal("Copy failed")); context->set_output(0, tmp); return; } // Optimization #2, slice is memory contiguous (only occurs in dim 0) if (slice_dim0 && IsDim0SliceAligned(input.shape(), begin[0], end[0])) { - CHECK_GE(input.dims(), 1); // Otherwise, is_identity should be true. + OP_REQUIRES(context, input.dims() >= 1, + errors::InvalidArgument( + "Input must have rank at least 1, got: ", input.dims())); + // Otherwise, is_identity should be true. VLOG(1) << "Strided slice dim 0: " << input.shape().DebugString(); + OP_REQUIRES( + context, begin[0] <= end[0], + errors::InvalidArgument("begin[0] (", begin[0], + ") must less or equal to end[0] (", end[0])); + Tensor slice = input.Slice(begin[0], end[0]); Tensor tmp; - CHECK(tmp.CopyFrom(input.Slice(begin[0], end[0]), final_shape)); + OP_REQUIRES(context, tmp.CopyFrom(slice, final_shape), + errors::Internal("Copy failed")); context->set_output(0, tmp); return; } @@ -238,7 +248,8 @@ class StridedSliceGradOp : public OpKernel { if (processing_shape.dims() == 0) { auto in = context->input(4); - CHECK(result->CopyFrom(in, processing_shape)); + OP_REQUIRES(context, result->CopyFrom(in, processing_shape), + errors::Internal("Copy failed")); return; } diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index e29f67297f9ce4a99898b256deda46ba95362904..22e45918a03833c784f23911061c5b049658ffbe 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -115,7 +115,7 @@ class SubstrOp : public OpKernel { Tensor input_buffer; OP_REQUIRES_OK(context, context->allocate_temp( DT_STRING, output_shape, &input_buffer)); - typename TTypes::Tensor input_bcast = + TTypes::Tensor input_bcast = input_buffer.shaped(bcast.result_shape()); input_bcast = input.broadcast(BCast::ToIndexArray<1>(bcast.x_bcast())); @@ -125,8 +125,8 @@ class SubstrOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), output_shape, &pos_buffer)); - typename TTypes::Tensor pos_bcast = - pos_buffer.shaped(bcast.result_shape()); + typename TTypes::Tensor pos_bcast( + pos_buffer.shaped(bcast.result_shape())); pos_bcast = pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); @@ -135,8 +135,8 @@ class SubstrOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), output_shape, &len_buffer)); - typename TTypes::Tensor len_bcast = - len_buffer.shaped(bcast.result_shape()); + typename TTypes::Tensor len_bcast( + len_buffer.shaped(bcast.result_shape())); len_bcast = len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); @@ -164,7 +164,7 @@ class SubstrOp : public OpKernel { Tensor input_buffer; OP_REQUIRES_OK(context, context->allocate_temp( DT_STRING, output_shape, &input_buffer)); - typename TTypes::Tensor input_bcast = + TTypes::Tensor input_bcast = input_buffer.shaped(bcast.result_shape()); input_bcast = input.broadcast(BCast::ToIndexArray<2>(bcast.x_bcast())); @@ -174,8 +174,8 @@ class SubstrOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), output_shape, &pos_buffer)); - typename TTypes::Tensor pos_bcast = - pos_buffer.shaped(bcast.result_shape()); + typename TTypes::Tensor pos_bcast( + pos_buffer.shaped(bcast.result_shape())); pos_bcast = pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); @@ -184,8 +184,8 @@ class SubstrOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::v(), output_shape, &len_buffer)); - typename TTypes::Tensor len_bcast = - len_buffer.shaped(bcast.result_shape()); + typename TTypes::Tensor len_bcast( + len_buffer.shaped(bcast.result_shape())); len_bcast = len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 07befa27bc54631d30e413a15972c560655418e0..233aa03c32333e62281cb8ab71828649b4fabe7e 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -1228,11 +1228,8 @@ inline T FtrlCompute(const T& accum, const T& linear, const T& lr, const T& l1, quadratic = Eigen::numext::pow(accum, -lr_power) / lr + static_cast(2) * l2; } - if (Eigen::numext::abs(linear) > l1) { - return (l1 * sgn(linear) - linear) / quadratic; - } else { - return static_cast(0.0); - } + auto l1_reg_adjust = std::max(std::min(linear, l1), -l1); + return (l1_reg_adjust - linear) / quadratic; } } // namespace diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc index 0ef8724b10e492373c7663a58420bfe236be7df7..31388e42904608f20edd48152330f9ad2fb7d0ca 100644 --- a/tensorflow/core/kernels/unique_op.cc +++ b/tensorflow/core/kernels/unique_op.cc @@ -223,6 +223,16 @@ class UniqueOp : public OpKernel { .Device(DEVICE_CPU) \ .TypeConstraint("T") \ .TypeConstraint("out_idx"), \ + UniqueOp); \ + REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ + UniqueOp) \ + REGISTER_KERNEL_BUILDER(Name("UniqueWithCountsV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .TypeConstraint("out_idx"), \ UniqueOp) TF_CALL_REAL_NUMBER_TYPES(REGISTER_UNIQUE); REGISTER_UNIQUE(string) diff --git a/tensorflow/core/kernels/variable_ops.cc b/tensorflow/core/kernels/variable_ops.cc index 10ccc85b7cd63db7f8d329a4253784abed7174cf..7fd5809ca49eba6af24d7dafe3b34b7f2c238279 100644 --- a/tensorflow/core/kernels/variable_ops.cc +++ b/tensorflow/core/kernels/variable_ops.cc @@ -237,6 +237,7 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SYCL_KERNEL); IsVariableInitializedOp); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); +TF_CALL_int64(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h index 83134bad378bfef18c3e93be5cc3c6b70ab4f523..8b406e5311cc33db943c1875a940fb886174cf28 100644 --- a/tensorflow/core/kernels/variable_ops.h +++ b/tensorflow/core/kernels/variable_ops.h @@ -45,6 +45,14 @@ class Var : public ResourceBase { tensor_.shape().DebugString()); } + // Only used in the resource variable path. In resource variables, + // tensor.IsInitialized() can be true (i.e. have memory allocated to it) while + // there is not a good value there due to a race condition, and it's possible + // to stumble upon this during variable.initialized_value(). So it's best to + // just store directly whether the variable is initialized. + bool is_initialized = false; // GUARDED_BY(mu_) but annotalysis doesn't like + // it. + private: mutex mu_; Tensor tensor_; diff --git a/tensorflow/core/kernels/xsmm_conv2d.cc b/tensorflow/core/kernels/xsmm_conv2d.cc index 601704c8a70f0b18c611cf8cd10d140314f61dc4..f8c06988cbac021d1f0924ca274c8bee5e9272a5 100644 --- a/tensorflow/core/kernels/xsmm_conv2d.cc +++ b/tensorflow/core/kernels/xsmm_conv2d.cc @@ -16,7 +16,7 @@ limitations under the License. // Make this file empty (or nearly empty) so that it can be compiled even when // libxsmm is not available. -#ifndef TENSORFLOW_USE_LIBXSMM +#ifndef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS void dummy_xsmm_conv2d_ensure_file_is_not_empty(); #else @@ -27,17 +27,14 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(); #include #include -#if 0 -#include -#endif #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/core/threadpool.h" -#include "libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/ #include "include/libxsmm_cpuid.h" #include "include/libxsmm_malloc.h" +#include "third_party/libxsmm/src/libxsmm_main.h" // TODO(bsteiner): API to avoid incl. header from src/ namespace tensorflow { @@ -176,8 +173,16 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, InputPtr input, FilterPtr filter, OutputPtr output) { #if defined(LIBXSMM_DETAILED_TIMING) - unsigned long long l_tick1, l_tick2, l_tick3, l_tick4, l_tick5, l_tick6, - l_tick7, l_tick8, l_tick9, l_tick10; + uint64 l_tick1; + uint64 l_tick2; + uint64 l_tick3; + uint64 l_tick4; + uint64 l_tick5; + uint64 l_tick6; + uint64 l_tick7; + uint64 l_tick8; + uint64 l_tick9; + uint64 l_tick10; l_tick1 = libxsmm_timer_tick(); #endif // setup scoped allocator, which adopts the allocator from the context @@ -360,7 +365,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, l_tick6 = libxsmm_timer_tick(); #endif -#if 1 BlockingCounter counter(num_threads); for (int i = 0; i < num_threads; ++i) { @@ -371,14 +375,6 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, }); } counter.Wait(); -#else -#pragma omp parallel - { - chk_libxsmm_err( - libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, omp_get_thread_num()), - "Worker"); - } -#endif #if defined(LIBXSMM_DETAILED_TIMING) l_tick7 = libxsmm_timer_tick(); @@ -465,6 +461,7 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx, return true; // Succeeded } +#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS template struct XsmmFwdConv2D { bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, @@ -473,7 +470,9 @@ struct XsmmFwdConv2D { input, filter, output); } }; +#endif +#ifdef TENSORFLOW_USE_LIBXSMM_BACKWARD_CONVOLUTIONS template struct XsmmBkwInputConv2D { bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc, @@ -491,6 +490,7 @@ struct XsmmBkwFilterConv2D { input, filter, output); } }; +#endif } // namespace functor @@ -500,4 +500,4 @@ template struct functor::XsmmBkwFilterConv2D; } // namespace tensorflow -#endif // TENSORFLOW_USE_LIBXSMM +#endif // TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 3657243c5d38a2076c1ca2c2e5f31b488b5a281b..ebc56482699948974ad434b6ea76fe26e1a4a5c5 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -49,7 +49,7 @@ RecordWriterOptions RecordWriterOptions::CreateRecordWriterOptions( #endif // IS_SLIM_BUILD } else if (compression_type != compression::kNone) { LOG(ERROR) << "Unsupported compression_type:" << compression_type - << ". No comprression will be used."; + << ". No compression will be used."; } return options; } diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc index 77a3414442caa523ab7a92e3e63babf581030287..cba473927dd1fce30bbe690b4bfda1e382ca12c0 100644 --- a/tensorflow/core/lib/png/png_io.cc +++ b/tensorflow/core/lib/png/png_io.cc @@ -90,11 +90,8 @@ void WarningHandler(png_structp png_ptr, png_const_charp msg) { void StringReader(png_structp png_ptr, png_bytep data, png_size_t length) { DecodeContext* const ctx = bit_cast(png_get_io_ptr(png_ptr)); if (static_cast(ctx->data_left) < length) { - if (!ctx->error_condition) { - VLOG(1) << "PNG read decoding error"; - ctx->error_condition = true; - } memset(data, 0, length); + png_error(png_ptr, "More bytes requested to read than available"); } else { memcpy(data, ctx->data, length); ctx->data += length; diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 267ce88440080399aae783903503f0bbd025d8b4..2fab62ea5cae6280554d2106f8f77d46017180e7 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -1201,6 +1201,23 @@ REGISTER_OP("UniqueWithCounts") return Status::OK(); }); +REGISTER_OP("UniqueWithCountsV2") + .Input("x: T") + .Input("axis: Taxis") + .Output("y: T") + .Output("idx: out_idx") + .Output("count: out_idx") + .Attr("T: type") + .Attr("Taxis: {int32,int64} = DT_INT64") + .Attr("out_idx: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + auto uniq = c->Vector(InferenceContext::kUnknownDim); + c->set_output(0, uniq); + c->set_output(1, c->input(0)); + c->set_output(2, uniq); + return Status::OK(); + }); + namespace { Status ShapeShapeFn(InferenceContext* c) { diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index fc9e5b02a2253621203a47c5f7d1b7d311c82a97..dddde1624a4f4258beae212014302f2599879d75 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -11460,6 +11460,14 @@ op { type: "type" } } +op { + name: "ConsumeMutexLock" + input_arg { + name: "mutex_lock" + type: DT_VARIANT + } + is_stateful: true +} op { name: "ControlTrigger" } @@ -12814,28 +12822,6 @@ op { } } } -op { - name: "CriticalSectionOp" - output_arg { - name: "resource" - type: DT_RESOURCE - } - attr { - name: "container" - type: "string" - default_value { - s: "" - } - } - attr { - name: "shared_name" - type: "string" - default_value { - s: "" - } - } - is_stateful: true -} op { name: "CropAndResize" input_arg { @@ -17433,78 +17419,6 @@ op { } } } -op { - name: "ExecuteInCriticalSection" - input_arg { - name: "critical_section" - type: DT_RESOURCE - } - input_arg { - name: "arguments" - type_list_attr: "Targuments" - } - output_arg { - name: "outputs" - type_list_attr: "output_types" - } - attr { - name: "f" - type: "func" - } - attr { - name: "Targuments" - type: "list(type)" - has_minimum: true - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } - is_stateful: true -} -op { - name: "ExecuteInCriticalSection" - input_arg { - name: "critical_section" - type: DT_RESOURCE - } - input_arg { - name: "arguments" - type_list_attr: "Targuments" - } - output_arg { - name: "outputs" - type_list_attr: "output_types" - } - attr { - name: "f" - type: "func" - } - attr { - name: "Targuments" - type: "list(type)" - has_minimum: true - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - } - is_stateful: true -} op { name: "Exit" input_arg { @@ -20556,6 +20470,65 @@ op { minimum: -1 } } +op { + name: "GeneratorDataset" + input_arg { + name: "init_func_other_args" + type_list_attr: "Tinit_func_args" + } + input_arg { + name: "next_func_other_args" + type_list_attr: "Tnext_func_args" + } + input_arg { + name: "finalize_func_other_args" + type_list_attr: "Tfinalize_func_args" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "init_func" + type: "func" + } + attr { + name: "next_func" + type: "func" + } + attr { + name: "finalize_func" + type: "func" + } + attr { + name: "Tinit_func_args" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tnext_func_args" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tfinalize_func_args" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "GetSessionHandle" input_arg { @@ -30112,6 +30085,40 @@ op { } is_stateful: true } +op { + name: "MutexLock" + input_arg { + name: "mutex" + type: DT_RESOURCE + } + output_arg { + name: "mutex_lock" + type: DT_VARIANT + } + is_stateful: true +} +op { + name: "MutexV2" + output_arg { + name: "resource" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} op { name: "Neg" input_arg { @@ -64366,6 +64373,14 @@ op { version: 3 } } +op { + name: "Timestamp" + output_arg { + name: "ts" + type: DT_DOUBLE + } + is_stateful: true +} op { name: "TopK" input_arg { @@ -65218,29 +65233,6 @@ op { } } } -op { - name: "UniqueDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} op { name: "UniqueV2" input_arg { diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 9e98f56c745a2b0b16531e2785e43ba8464d42b8..bdbbf6d7c32014678d8ad171df03c29a4a44f422 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -66,6 +66,23 @@ REGISTER_OP("SparseTensorSliceDataset") // stateful to inhibit constant folding. .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("GeneratorDataset") + .Input("init_func_other_args: Tinit_func_args") + .Input("next_func_other_args: Tnext_func_args") + .Input("finalize_func_other_args: Tfinalize_func_args") + .Output("handle: variant") + .Attr("init_func: func") + .Attr("next_func: func") + .Attr("finalize_func: func") + .Attr("Tinit_func_args: list(type) >= 0") + .Attr("Tnext_func_args: list(type) >= 0") + .Attr("Tfinalize_func_args: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("ZipDataset") .Input("input_datasets: N * variant") .Output("handle: variant") @@ -329,13 +346,6 @@ REGISTER_OP("CacheDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); -REGISTER_OP("UniqueDataset") - .Input("input_dataset: variant") - .Output("handle: variant") - .Attr("output_types: list(type) >= 1") - .Attr("output_shapes: list(shape) >= 1") - .SetShapeFn(shape_inference::ScalarShape); - REGISTER_OP("TextLineDataset") .Input("filenames: string") .Input("compression_type: string") diff --git a/tensorflow/core/ops/function_ops.cc b/tensorflow/core/ops/function_ops.cc index ada96fa1d2ddf79b2669fa3fc437ce7b872a2eb1..a6914d9383d2f5c623b17fb0b918c4907ed84175 100644 --- a/tensorflow/core/ops/function_ops.cc +++ b/tensorflow/core/ops/function_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -55,6 +56,7 @@ REGISTER_OP("_ListToArray") .Attr("Tin: list(type)") .Attr("T: type") .Attr("N: int >= 1") + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Converts a list of tensors to an array of tensors. )doc"); @@ -65,6 +67,7 @@ REGISTER_OP("_ArrayToList") .Attr("T: type") .Attr("N: int >= 1") .Attr("out_types: list(type)") + .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( Converts an array of tensors to a list of tensors. )doc"); diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc index d263dc25b29d5c867a10ef20ea1b39fa9b9662f1..fbde692e959769fca53c91fef649b18c248526a6 100644 --- a/tensorflow/core/ops/logging_ops.cc +++ b/tensorflow/core/ops/logging_ops.cc @@ -111,4 +111,9 @@ REGISTER_OP("MergeSummary") .Attr("N : int >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("Timestamp") + .Output("ts: float64") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape); + } // end namespace tensorflow diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 45ff08f38b134f963460d15f949411a7f1619d0c..55be0519a797f29994945d9c2fa44d27a5f0ad0f 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -4773,6 +4773,14 @@ op { type: "type" } } +op { + name: "ConsumeMutexLock" + input_arg { + name: "mutex_lock" + type: DT_VARIANT + } + is_stateful: true +} op { name: "ControlTrigger" } @@ -5465,28 +5473,6 @@ op { } } } -op { - name: "CriticalSectionOp" - output_arg { - name: "resource" - type: DT_RESOURCE - } - attr { - name: "container" - type: "string" - default_value { - s: "" - } - } - attr { - name: "shared_name" - type: "string" - default_value { - s: "" - } - } - is_stateful: true -} op { name: "CropAndResize" input_arg { @@ -7788,41 +7774,6 @@ op { } } } -op { - name: "ExecuteInCriticalSection" - input_arg { - name: "critical_section" - type: DT_RESOURCE - } - input_arg { - name: "arguments" - type_list_attr: "Targuments" - } - output_arg { - name: "outputs" - type_list_attr: "output_types" - } - attr { - name: "f" - type: "func" - } - attr { - name: "Targuments" - type: "list(type)" - has_minimum: true - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - } - is_stateful: true -} op { name: "Exit" input_arg { @@ -9656,6 +9607,65 @@ op { minimum: -1 } } +op { + name: "GeneratorDataset" + input_arg { + name: "init_func_other_args" + type_list_attr: "Tinit_func_args" + } + input_arg { + name: "next_func_other_args" + type_list_attr: "Tnext_func_args" + } + input_arg { + name: "finalize_func_other_args" + type_list_attr: "Tfinalize_func_args" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "init_func" + type: "func" + } + attr { + name: "next_func" + type: "func" + } + attr { + name: "finalize_func" + type: "func" + } + attr { + name: "Tinit_func_args" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tnext_func_args" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tfinalize_func_args" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} op { name: "GetSessionHandle" input_arg { @@ -14308,6 +14318,40 @@ op { } is_stateful: true } +op { + name: "MutexLock" + input_arg { + name: "mutex" + type: DT_RESOURCE + } + output_arg { + name: "mutex_lock" + type: DT_VARIANT + } + is_stateful: true +} +op { + name: "MutexV2" + output_arg { + name: "resource" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} op { name: "Neg" input_arg { @@ -30368,6 +30412,14 @@ op { explanation: "TileGrad has been replaced with reduce_sum" } } +op { + name: "Timestamp" + output_arg { + name: "ts" + type: DT_DOUBLE + } + is_stateful: true +} op { name: "TopK" input_arg { @@ -30778,29 +30830,6 @@ op { } } } -op { - name: "UniqueDataset" - input_arg { - name: "input_dataset" - type: DT_VARIANT - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } -} op { name: "UniqueV2" input_arg { diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index 8dae7e1ff5f872c33dd56509c0349180cec78593..0d8cf78cc2a196cde4a77f53ce912c437648786a 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -211,7 +211,7 @@ REGISTER_OP("ResourceScatterUpdate") return Status::OK(); }); -REGISTER_OP("CriticalSectionOp") +REGISTER_OP("MutexV2") .Attr("container: string = ''") .Attr("shared_name: string = ''") .Output("resource: resource") @@ -221,24 +221,18 @@ REGISTER_OP("CriticalSectionOp") return Status::OK(); }); -REGISTER_OP("ExecuteInCriticalSection") - .Input("critical_section: resource") - .Input("arguments: Targuments") - .Output("outputs: output_types") - .Attr("f: func") - .Attr("Targuments: list(type) >= 0") - .Attr("output_types: list(type) >= 0") - .Attr("output_shapes: list(shape) >= 0") +REGISTER_OP("MutexLock") + .Input("mutex: resource") + .Output("mutex_lock: variant") + .SetIsStateful() .SetShapeFn([](InferenceContext* c) { - std::vector output_shapes; - TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); - for (int i = 0; i < output_shapes.size(); ++i) { - ShapeHandle s; - TF_RETURN_IF_ERROR( - c->MakeShapeFromPartialTensorShape(output_shapes[i], &s)); - c->set_output(i, s); - } + c->set_output(0, c->Scalar()); return Status::OK(); }); +REGISTER_OP("ConsumeMutexLock") + .Input("mutex_lock: variant") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { return Status::OK(); }); + } // namespace tensorflow diff --git a/tensorflow/core/ops/shape_function_test.cc b/tensorflow/core/ops/shape_function_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..120995f3aac7da4111d0404a64f322a50d30a491 --- /dev/null +++ b/tensorflow/core/ops/shape_function_test.cc @@ -0,0 +1,34 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); + +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/platform/test.h" + +// Test to ensure that all core ops have shape functions defined. This is done +// by looking at all ops registered in the test binary. + +namespace tensorflow { + +TEST(ShapeFunctionTest, RegisteredOpsHaveShapeFns) { + OpRegistry* op_registry = OpRegistry::Global(); + std::vector op_data; + op_registry->GetOpRegistrationData(&op_data); + for (const OpRegistrationData& op_reg_data : op_data) { + EXPECT_TRUE(op_reg_data.shape_inference_fn != nullptr) + << op_reg_data.op_def.name(); + } +} + +} // namespace tensorflow diff --git a/tensorflow/core/ops/spectral_ops.cc b/tensorflow/core/ops/spectral_ops.cc index 508cea3495a9e811d4d12bf022b0ddfdcb33d718..2790aee37e93d3915ff9cba80af2e7ddccf4774e 100644 --- a/tensorflow/core/ops/spectral_ops.cc +++ b/tensorflow/core/ops/spectral_ops.cc @@ -142,26 +142,32 @@ REGISTER_OP("IRFFT3D") REGISTER_OP("BatchFFT") .Input("input: complex64") .Output("output: complex64") + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(15, "Use FFT"); REGISTER_OP("BatchIFFT") .Input("input: complex64") .Output("output: complex64") + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(15, "Use IFFT"); REGISTER_OP("BatchFFT2D") .Input("input: complex64") .Output("output: complex64") + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(15, "Use FFT2D"); REGISTER_OP("BatchIFFT2D") .Input("input: complex64") .Output("output: complex64") + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(15, "Use IFFT2D"); REGISTER_OP("BatchFFT3D") .Input("input: complex64") .Output("output: complex64") + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(15, "Use FFT3D"); REGISTER_OP("BatchIFFT3D") .Input("input: complex64") .Output("output: complex64") + .SetShapeFn(shape_inference::UnknownShape) .Deprecated(15, "Use IFFT3D"); } // namespace tensorflow diff --git a/tensorflow/core/ops/word2vec_ops.cc b/tensorflow/core/ops/word2vec_ops.cc index ed685dcf0ae9a3c61a1db491751f7de4e981300d..e469771103925e107d2f8aeced6df9dfb56cbe24 100644 --- a/tensorflow/core/ops/word2vec_ops.cc +++ b/tensorflow/core/ops/word2vec_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" namespace tensorflow { @@ -33,7 +34,8 @@ REGISTER_OP("Skipgram") .Attr("batch_size: int") .Attr("window_size: int = 5") .Attr("min_count: int = 5") - .Attr("subsample: float = 1e-3"); + .Attr("subsample: float = 1e-3") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("NegTrain") .Deprecated(19, @@ -46,6 +48,7 @@ REGISTER_OP("NegTrain") .Input("lr: float") .SetIsStateful() .Attr("vocab_count: list(int)") - .Attr("num_negative_samples: int"); + .Attr("num_negative_samples: int") + .SetShapeFn(shape_inference::UnknownShape); } // end namespace tensorflow diff --git a/tensorflow/core/platform/denormal.cc b/tensorflow/core/platform/denormal.cc index e00dbdb4ae5ef682369b345353e236a6084460ef..3631d9ddf99430372c11403dba56c14331a3db24 100644 --- a/tensorflow/core/platform/denormal.cc +++ b/tensorflow/core/platform/denormal.cc @@ -40,36 +40,51 @@ limitations under the License. namespace tensorflow { namespace port { -ScopedFlushDenormal::ScopedFlushDenormal() { +static void SetDenormalState(bool flush_zero_mode, bool denormals_zero_mode) { // For now, we flush denormals only on SSE 3. Other architectures such as ARM // can be added as needed. #ifdef DENORM_USE_INTRINSICS if (TestCPUFeature(SSE3)) { - // Save existing flags - flush_zero_mode_ = _MM_GET_FLUSH_ZERO_MODE() == _MM_FLUSH_ZERO_ON; - denormals_zero_mode_ = - _MM_GET_DENORMALS_ZERO_MODE() == _MM_DENORMALS_ZERO_ON; - - // Flush denormals to zero (the FTZ flag). - _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON); - - // Interpret denormal inputs as zero (the DAZ flag). - _MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON); + // Restore flags + _MM_SET_FLUSH_ZERO_MODE(flush_zero_mode ? _MM_FLUSH_ZERO_ON + : _MM_FLUSH_ZERO_OFF); + _MM_SET_DENORMALS_ZERO_MODE(denormals_zero_mode ? _MM_DENORMALS_ZERO_ON + : _MM_DENORMALS_ZERO_OFF); } #endif } -ScopedFlushDenormal::~ScopedFlushDenormal() { +static std::pair GetDernormalState() { + // For now, we flush denormals only on SSE 3. Other architectures such as ARM + // can be added as needed. + #ifdef DENORM_USE_INTRINSICS if (TestCPUFeature(SSE3)) { - // Restore flags - _MM_SET_FLUSH_ZERO_MODE(flush_zero_mode_ ? _MM_FLUSH_ZERO_ON - : _MM_FLUSH_ZERO_OFF); - _MM_SET_DENORMALS_ZERO_MODE(denormals_zero_mode_ ? _MM_DENORMALS_ZERO_ON - : _MM_DENORMALS_ZERO_OFF); + // Save existing flags + bool flush_zero_mode = _MM_GET_FLUSH_ZERO_MODE() == _MM_FLUSH_ZERO_ON; + bool denormals_zero_mode = + _MM_GET_DENORMALS_ZERO_MODE() == _MM_DENORMALS_ZERO_ON; + return {flush_zero_mode, denormals_zero_mode}; } #endif + return {false, false}; +} + +ScopedRestoreFlushDenormalState::ScopedRestoreFlushDenormalState() { + std::tie(flush_zero_mode_, denormals_zero_mode_) = GetDernormalState(); +} + +ScopedRestoreFlushDenormalState::~ScopedRestoreFlushDenormalState() { + SetDenormalState(flush_zero_mode_, denormals_zero_mode_); +} + +ScopedFlushDenormal::ScopedFlushDenormal() { + SetDenormalState(/*flush_zero_mode=*/true, /*denormals_zero_mode=*/true); +} + +ScopedDontFlushDenormal::ScopedDontFlushDenormal() { + SetDenormalState(/*flush_zero_mode=*/false, /*denormals_zero_mode=*/false); } } // namespace port diff --git a/tensorflow/core/platform/denormal.h b/tensorflow/core/platform/denormal.h index 5e34131a3b8d8ec5b74bf66add1567e4f5207a02..09bb0352a2f375fac73054ca516cee79905795c1 100644 --- a/tensorflow/core/platform/denormal.h +++ b/tensorflow/core/platform/denormal.h @@ -21,19 +21,41 @@ limitations under the License. namespace tensorflow { namespace port { +// Remembers the flush denormal state on construction and restores that same +// state on destruction. +class ScopedRestoreFlushDenormalState { + public: + ScopedRestoreFlushDenormalState(); + ~ScopedRestoreFlushDenormalState(); + + private: + bool flush_zero_mode_; + bool denormals_zero_mode_; + TF_DISALLOW_COPY_AND_ASSIGN(ScopedRestoreFlushDenormalState); +}; + // While this class is active, denormal floating point numbers are flushed // to zero. The destructor restores the original flags. class ScopedFlushDenormal { public: ScopedFlushDenormal(); - ~ScopedFlushDenormal(); private: - bool flush_zero_mode_; - bool denormals_zero_mode_; + ScopedRestoreFlushDenormalState restore_; TF_DISALLOW_COPY_AND_ASSIGN(ScopedFlushDenormal); }; +// While this class is active, denormal floating point numbers are not flushed +// to zero. The destructor restores the original flags. +class ScopedDontFlushDenormal { + public: + ScopedDontFlushDenormal(); + + private: + ScopedRestoreFlushDenormalState restore_; + TF_DISALLOW_COPY_AND_ASSIGN(ScopedDontFlushDenormal); +}; + } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/platform.h b/tensorflow/core/platform/platform.h index 12120c4ab96ae8327864c46a8e0dc434b900e67e..0481b3687137c8b00fa84d33eb317a1a4f5be9df 100644 --- a/tensorflow/core/platform/platform.h +++ b/tensorflow/core/platform/platform.h @@ -43,10 +43,11 @@ limitations under the License. #elif defined(__arm__) #define PLATFORM_POSIX -// Require an outside macro to tell us if we're building for Raspberry Pi. -#if !defined(RASPBERRY_PI) +// Require an outside macro to tell us if we're building for Raspberry Pi or +// another ARM device that's not a mobile platform. +#if !defined(RASPBERRY_PI) && !defined(ARM_NON_MOBILE) #define IS_MOBILE_PLATFORM -#endif // !defined(RASPBERRY_PI) +#endif // !defined(RASPBERRY_PI) && !defined(ARM_NON_MOBILE) #else // If no platform specified, use: diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index 582b232054b850a2ef5ab8f47c089eb35a7bb3cf..f3b27ea394d04770b612752328d5d571e6521cc6 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -25,6 +25,7 @@ limitations under the License. #endif #include +#include #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/demangle.h" @@ -149,11 +150,16 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) { string Demangle(const char* mangled) { return mangled; } double NominalCPUFrequency() { -#ifdef TENSORFLOW_USE_ABSL - return absl::base_internal::NominalCPUFrequency(); -#else + DWORD data; + DWORD data_size = sizeof(data); + #pragma comment(lib, "shlwapi.lib") // For SHGetValue(). + if (SUCCEEDED( + SHGetValueA(HKEY_LOCAL_MACHINE, + "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0", + "~MHz", nullptr, &data, &data_size))) { + return data * 1e6; // Value is MHz. + } return 1.0; -#endif } int64 AvailableRam() { diff --git a/tensorflow/core/protobuf/control_flow.proto b/tensorflow/core/protobuf/control_flow.proto index 2c9476a08ad946e7f019475055397fcd6cfbbc5a..3c05b4f0e22e5ce2104980ad4fa52c8d8ad57070 100644 --- a/tensorflow/core/protobuf/control_flow.proto +++ b/tensorflow/core/protobuf/control_flow.proto @@ -17,6 +17,15 @@ message ValuesDef { map external_values = 2; } +// Container for any kind of control flow context. Any other control flow +// contexts that are added below should also be added here. +message ControlFlowContextDef { + oneof ctxt { + CondContextDef cond_ctxt = 1; + WhileContextDef while_ctxt = 2; + } +} + // Protocol buffer representing a CondContext object. message CondContextDef { // Name of the context. @@ -33,6 +42,9 @@ message CondContextDef { // Values and external values in control flow context. ValuesDef values_def = 5; + + // Contexts contained inside this context (e.g. nested conds). + repeated ControlFlowContextDef nested_contexts = 6; } // Protocol buffer representing a WhileContext object. @@ -70,5 +82,8 @@ message WhileContextDef { // Optional name of the maximum_iterations tensor. string maximum_iterations_name = 11; - // Next available id: 12. + // Contexts contained inside this context (e.g. nested whiles). + repeated ControlFlowContextDef nested_contexts = 12; + + // Next available id: 13. } diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index dddadceeb5ffd4f52abbcfd03b491c585c6612be..875e4663db6f82a002fd72cbd09052ee2e0510a5 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -30,29 +30,44 @@ message RewriterConfig { } // Optimize tensor layouts (default is ON) + // e.g. This will try to use NCHW layout on GPU which is faster. Toggle layout_optimizer = 1; // Fold constants (default is ON) + // Statically infer the value of tensors when possible, and materialize the + // result using constants. Toggle constant_folding = 3; // Arithmetic optimizations (default is ON) + // e.g. Simplify arithmetic ops; merge ops with same value (like constants). Toggle arithmetic_optimization = 7; // Control dependency optimizations (default is ON). + // Remove redundant control dependencies, which may enable other optimization. Toggle dependency_optimization = 8; + // Loop optimizations (default is OFF). + Toggle loop_optimization = 9; // If true, don't remove unnecessary ops from the graph bool disable_model_pruning = 2; enum MemOptType { - // The default setting (currently disabled) + // The default setting (SCHEDULING and SWAPPING HEURISTICS only) DEFAULT_MEM_OPT = 0; // Disabled in the meta-optimizer. NO_MEM_OPT = 1; // Driven by manual op-level annotations. MANUAL = 2; + // Driven by heuristics. The behavior of these heuristics is subject to // change. Currently includes an experimental recomputation and swapping // heuristics. Manual annotations are respected, but additional nodes are // selected automatically. + + // Swapping heuristic will move a tensor from the GPU to the CPU and move + // it back when needed to reduce peak memory usage. SWAPPING_HEURISTICS = 4; + // Recomputation heuristics will recompute ops (such as Relu activation) + // during backprop instead of storing them, reducing peak memory usage. RECOMPUTATION_HEURISTICS = 5; + // Scheduling will split big ops such as AddN and try to enforce a schedule + // of the new computations that decreases peak memory usage. SCHEDULING_HEURISTICS = 6; // Use any combination of swapping and recomputation heuristics. HEURISTICS = 3; @@ -85,5 +100,8 @@ message RewriterConfig { // ("autoparallel"). Memory optimization passes ("memory") invoked here are // not configurable (in contrast to memory optimization passes through the // meta-optimizer) and act only on manual op annotations. + // + // Custom registered optimizers will be run after the base optimizers, in + // the order that they are specified. repeated string optimizers = 100; } diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 50bfa9126789033c617e22f25dbb76273fccfc60..7405e01e14494fb6e4e241f1a2b8bc33a4200fa7 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -24,7 +24,7 @@ limitations under the License. // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1", // "-beta", "-rc", "-rc.1") -#define TF_VERSION_SUFFIX "-rc0" +#define TF_VERSION_SUFFIX "-rc1" #define TF_STR_HELPER(x) #x #define TF_STR(x) TF_STR_HELPER(x) diff --git a/tensorflow/core/user_ops/fact.cc b/tensorflow/core/user_ops/fact.cc index 3a4fc8115a7f91badfeda369a599b3dba3057c63..2e8b22a49b620d08aa4f13da35e847b362dd2b3a 100644 --- a/tensorflow/core/user_ops/fact.cc +++ b/tensorflow/core/user_ops/fact.cc @@ -15,10 +15,13 @@ limitations under the License. // An example Op. +#include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" -REGISTER_OP("Fact").Output("fact: string"); +REGISTER_OP("Fact") + .Output("fact: string") + .SetShapeFn(tensorflow::shape_inference::UnknownShape); class FactOp : public tensorflow::OpKernel { public: diff --git a/tensorflow/core/util/cuda_launch_config.h b/tensorflow/core/util/cuda_launch_config.h index 3ea33ee6cf2195cc0192c59d694672f0d4c69a56..81df7a51d703986b040b5d15e128139ae56c24fb 100644 --- a/tensorflow/core/util/cuda_launch_config.h +++ b/tensorflow/core/util/cuda_launch_config.h @@ -169,6 +169,30 @@ inline CudaLaunchConfig GetCudaLaunchConfig(int work_element_count, return config; } +// Calculate the Cuda launch config we should use for a kernel launch. This +// variant takes the resource limits of func into account to maximize occupancy. +// The returned launch config has thread_per_block set to fixed_block_size. +// REQUIRES: work_element_count > 0. +template +inline CudaLaunchConfig GetCudaLaunchConfigFixedBlockSize( + int work_element_count, const Eigen::GpuDevice& d, DeviceFunc func, + size_t dynamic_shared_memory_size, int fixed_block_size) { + CHECK_GT(work_element_count, 0); + CudaLaunchConfig config; + int block_count = 0; + + cudaError_t err = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &block_count, func, fixed_block_size, dynamic_shared_memory_size); + CHECK_EQ(err, cudaSuccess); + block_count = std::min(block_count * d.getNumCudaMultiProcessors(), + DivUp(work_element_count, fixed_block_size)); + + config.virtual_thread_count = work_element_count; + config.thread_per_block = fixed_block_size; + config.block_count = block_count; + return config; +} + struct Cuda2DLaunchConfig { dim3 virtual_thread_count = dim3(0, 0, 0); dim3 thread_per_block = dim3(0, 0, 0); @@ -236,20 +260,18 @@ inline Cuda3DLaunchConfig GetCuda3DLaunchConfig( block_size_limit); CHECK_EQ(err, cudaSuccess); - auto min3 = [](int a, int b, int c) { return std::min(a, std::min(b, c)); }; - - int threadsx = min3(xdim, thread_per_block, xthreadlimit); + int threadsx = std::min({xdim, thread_per_block, xthreadlimit}); int threadsy = - min3(ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit); + std::min({ydim, std::max(thread_per_block / threadsx, 1), ythreadlimit}); int threadsz = - min3(zdim, std::max(thread_per_block / (threadsx * threadsy), 1), - zthreadlimit); - - int blocksx = min3(block_count, DivUp(xdim, threadsx), xgridlimit); - int blocksy = - min3(DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit); - int blocksz = min3(DivUp(block_count, (blocksx * blocksy)), - DivUp(zdim, threadsz), zgridlimit); + std::min({zdim, std::max(thread_per_block / (threadsx * threadsy), 1), + zthreadlimit}); + + int blocksx = std::min({block_count, DivUp(xdim, threadsx), xgridlimit}); + int blocksy = std::min( + {DivUp(block_count, blocksx), DivUp(ydim, threadsy), ygridlimit}); + int blocksz = std::min({DivUp(block_count, (blocksx * blocksy)), + DivUp(zdim, threadsz), zgridlimit}); config.virtual_thread_count = dim3(xdim, ydim, zdim); config.thread_per_block = dim3(threadsx, threadsy, threadsz); diff --git a/tensorflow/core/util/events_writer.cc b/tensorflow/core/util/events_writer.cc index 23b00e23dd0e7054aaf0e4e442c60f1372ce2d5b..49507616ed8c6461f8d59d8899d93abb4ba58cd2 100644 --- a/tensorflow/core/util/events_writer.cc +++ b/tensorflow/core/util/events_writer.cc @@ -17,6 +17,7 @@ limitations under the License. #include // for NULL +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -35,10 +36,21 @@ EventsWriter::EventsWriter(const string& file_prefix) file_prefix_(file_prefix), num_outstanding_events_(0) {} -bool EventsWriter::InitIfNeeded() { +EventsWriter::~EventsWriter() { + Close().IgnoreError(); // Autoclose in destructor. +} + +Status EventsWriter::Init() { return InitWithSuffix(""); } + +Status EventsWriter::InitWithSuffix(const string& suffix) { + file_suffix_ = suffix; + return InitIfNeeded(); +} + +Status EventsWriter::InitIfNeeded() { if (recordio_writer_ != nullptr) { CHECK(!filename_.empty()); - if (FileHasDisappeared()) { + if (!FileStillExists().ok()) { // Warn user of data loss and let .reset() below do basic cleanup. if (num_outstanding_events_ > 0) { LOG(WARNING) << "Re-initialization, attempting to open a new file, " @@ -46,7 +58,7 @@ bool EventsWriter::InitIfNeeded() { } } else { // No-op: File is present and writer is initialized. - return true; + return Status::OK(); } } @@ -57,15 +69,12 @@ bool EventsWriter::InitIfNeeded() { static_cast(time_in_seconds), port::Hostname().c_str(), file_suffix_.c_str()); - Status s = env_->NewWritableFile(filename_, &recordio_file_); - if (!s.ok()) { - LOG(ERROR) << "Could not open events file: " << filename_ << ": " << s; - return false; - } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + env_->NewWritableFile(filename_, &recordio_file_), + "Creating writable file ", filename_); recordio_writer_.reset(new io::RecordWriter(recordio_file_.get())); if (recordio_writer_ == nullptr) { - LOG(ERROR) << "Could not create record writer"; - return false; + return errors::Unknown("Could not create record writer"); } num_outstanding_events_ = 0; VLOG(1) << "Successfully opened events file: " << filename_; @@ -77,21 +86,21 @@ bool EventsWriter::InitIfNeeded() { event.set_wall_time(time_in_seconds); event.set_file_version(strings::StrCat(kVersionPrefix, kCurrentVersion)); WriteEvent(event); - Flush(); + TF_RETURN_WITH_CONTEXT_IF_ERROR(Flush(), "Flushing first event."); } - return true; + return Status::OK(); } string EventsWriter::FileName() { if (filename_.empty()) { - InitIfNeeded(); + InitIfNeeded().IgnoreError(); } return filename_; } void EventsWriter::WriteSerializedEvent(StringPiece event_str) { if (recordio_writer_ == nullptr) { - if (!InitIfNeeded()) { + if (!InitIfNeeded().ok()) { LOG(ERROR) << "Write failed because file could not be opened."; return; } @@ -108,60 +117,51 @@ void EventsWriter::WriteEvent(const Event& event) { WriteSerializedEvent(record); } -bool EventsWriter::Flush() { - if (num_outstanding_events_ == 0) return true; +Status EventsWriter::Flush() { + if (num_outstanding_events_ == 0) return Status::OK(); CHECK(recordio_file_ != nullptr) << "Unexpected NULL file"; - if (!recordio_writer_->Flush().ok()) { - LOG(ERROR) << "Failed to flush " << num_outstanding_events_ << " events to " - << filename_; - return false; - } + TF_RETURN_WITH_CONTEXT_IF_ERROR(recordio_writer_->Flush(), "Failed to flush ", + num_outstanding_events_, " to ", filename_); + TF_RETURN_WITH_CONTEXT_IF_ERROR(recordio_file_->Sync(), "Failed to sync ", + num_outstanding_events_, " to ", filename_); - // The FileHasDisappeared() condition is necessary because - // recordio_writer_->Sync() can return true even if the underlying + // The FileStillExists() condition is necessary because + // recordio_writer_->Sync() can return OK even if the underlying // file has been deleted. EventWriter.FileDeletionBeforeWriting // demonstrates this and will fail if the FileHasDisappeared() // condition is removed. // Also, we deliberately attempt to Sync() before checking for a // disappearing file, in case for some file system File::Exists() is // false after File::Open() but before File::Sync(). - if (!recordio_file_->Flush().ok() || !recordio_file_->Sync().ok() || - FileHasDisappeared()) { - LOG(ERROR) << "Failed to flush " << num_outstanding_events_ << " events to " - << filename_; - return false; - } + TF_RETURN_WITH_CONTEXT_IF_ERROR(FileStillExists(), "Failed to flush ", + num_outstanding_events_, " to ", filename_); VLOG(1) << "Wrote " << num_outstanding_events_ << " events to disk."; num_outstanding_events_ = 0; - return true; + return Status::OK(); } -bool EventsWriter::Close() { - bool return_value = Flush(); +Status EventsWriter::Close() { + Status status = Flush(); if (recordio_file_ != nullptr) { - Status s = recordio_file_->Close(); - if (!s.ok()) { - LOG(ERROR) << "Error when closing previous event file: " << filename_ - << ": " << s; - return_value = false; + Status close_status = recordio_file_->Close(); + if (!close_status.ok()) { + status = close_status; } recordio_writer_.reset(nullptr); recordio_file_.reset(nullptr); } num_outstanding_events_ = 0; - return return_value; + return status; } -bool EventsWriter::FileHasDisappeared() { +Status EventsWriter::FileStillExists() { if (env_->FileExists(filename_).ok()) { - return false; - } else { - // This can happen even with non-null recordio_writer_ if some other - // process has removed the file. - LOG(ERROR) << "The events file " << filename_ << " has disappeared."; - return true; + return Status::OK(); } + // This can happen even with non-null recordio_writer_ if some other + // process has removed the file. + return errors::Unknown("The events file ", filename_, " has disappeared."); } } // namespace tensorflow diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h index a1a8cf790d4e2735d705cc2050c14970e5bfab4a..5dbaf97af4ad145cb09009b44d6f93d1c270d17d 100644 --- a/tensorflow/core/util/events_writer.h +++ b/tensorflow/core/util/events_writer.h @@ -18,6 +18,8 @@ limitations under the License. #include #include + +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/record_writer.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/macros.h" @@ -43,7 +45,7 @@ class EventsWriter { // Note that it is not recommended to simultaneously have two // EventWriters writing to the same file_prefix. explicit EventsWriter(const string& file_prefix); - ~EventsWriter() { Close(); } // Autoclose in destructor. + ~EventsWriter(); // Sets the event file filename and opens file for writing. If not called by // user, will be invoked automatically by a call to FileName() or Write*(). @@ -51,11 +53,8 @@ class EventsWriter { // and is open this is a no-op. If on the other hand the file was opened, // but has since disappeared (e.g. deleted by another process), this will open // a new file with a new timestamp in its filename. - bool Init() { return InitWithSuffix(""); } - bool InitWithSuffix(const string& suffix) { - file_suffix_ = suffix; - return InitIfNeeded(); - } + Status Init(); + Status InitWithSuffix(const string& suffix); // Returns the filename for the current events file: // filename_ = [file_prefix_].out.events.[timestamp].[hostname][suffix] @@ -77,12 +76,12 @@ class EventsWriter { // be written too. // Close() calls Flush() and then closes the current events file. // Returns true only if both the flush and the closure were successful. - bool Flush(); - bool Close(); + Status Flush(); + Status Close(); private: - bool FileHasDisappeared(); // True if event_file_path_ does not exist. - bool InitIfNeeded(); + Status FileStillExists(); // OK if event_file_path_ exists. + Status InitIfNeeded(); Env* env_; const string file_prefix_; diff --git a/tensorflow/core/util/events_writer_test.cc b/tensorflow/core/util/events_writer_test.cc index a6286ea701f09b94fe18cb373a42b5a83aab893a..a75b26abc631eb782ba527f9d15ac25ce9f72b2b 100644 --- a/tensorflow/core/util/events_writer_test.cc +++ b/tensorflow/core/util/events_writer_test.cc @@ -112,7 +112,7 @@ TEST(EventWriter, WriteFlush) { string file_prefix = GetDirName("/writeflush_test"); EventsWriter writer(file_prefix); WriteFile(&writer); - EXPECT_TRUE(writer.Flush()); + TF_EXPECT_OK(writer.Flush()); string filename = writer.FileName(); VerifyFile(filename); } @@ -121,7 +121,7 @@ TEST(EventWriter, WriteClose) { string file_prefix = GetDirName("/writeclose_test"); EventsWriter writer(file_prefix); WriteFile(&writer); - EXPECT_TRUE(writer.Close()); + TF_EXPECT_OK(writer.Close()); string filename = writer.FileName(); VerifyFile(filename); } @@ -143,7 +143,7 @@ TEST(EventWriter, FailFlush) { TF_EXPECT_OK(env()->FileExists(filename)); TF_ASSERT_OK(env()->DeleteFile(filename)); EXPECT_EQ(errors::Code::NOT_FOUND, env()->FileExists(filename).code()); - EXPECT_FALSE(writer.Flush()); + EXPECT_FALSE(writer.Flush().ok()); EXPECT_EQ(errors::Code::NOT_FOUND, env()->FileExists(filename).code()); } @@ -155,18 +155,18 @@ TEST(EventWriter, FailClose) { TF_EXPECT_OK(env()->FileExists(filename)); TF_ASSERT_OK(env()->DeleteFile(filename)); EXPECT_EQ(errors::Code::NOT_FOUND, env()->FileExists(filename).code()); - EXPECT_FALSE(writer.Close()); + EXPECT_FALSE(writer.Close().ok()); EXPECT_EQ(errors::Code::NOT_FOUND, env()->FileExists(filename).code()); } TEST(EventWriter, InitWriteClose) { string file_prefix = GetDirName("/initwriteclose_test"); EventsWriter writer(file_prefix); - EXPECT_TRUE(writer.Init()); + TF_EXPECT_OK(writer.Init()); string filename0 = writer.FileName(); TF_EXPECT_OK(env()->FileExists(filename0)); WriteFile(&writer); - EXPECT_TRUE(writer.Close()); + TF_EXPECT_OK(writer.Close()); string filename1 = writer.FileName(); EXPECT_EQ(filename0, filename1); VerifyFile(filename1); @@ -178,7 +178,7 @@ TEST(EventWriter, NameWriteClose) { string filename = writer.FileName(); TF_EXPECT_OK(env()->FileExists(filename)); WriteFile(&writer); - EXPECT_TRUE(writer.Close()); + TF_EXPECT_OK(writer.Close()); VerifyFile(filename); } @@ -186,7 +186,7 @@ TEST(EventWriter, NameClose) { string file_prefix = GetDirName("/nameclose_test"); EventsWriter writer(file_prefix); string filename = writer.FileName(); - EXPECT_TRUE(writer.Close()); + TF_EXPECT_OK(writer.Close()); TF_EXPECT_OK(env()->FileExists(filename)); TF_ASSERT_OK(env()->DeleteFile(filename)); } @@ -199,9 +199,9 @@ TEST(EventWriter, FileDeletionBeforeWriting) { env()->SleepForMicroseconds( 2000000); // To make sure timestamp part of filename will differ. TF_ASSERT_OK(env()->DeleteFile(filename0)); - EXPECT_TRUE(writer.Init()); // Init should reopen file. + TF_EXPECT_OK(writer.Init()); // Init should reopen file. WriteFile(&writer); - EXPECT_TRUE(writer.Flush()); + TF_EXPECT_OK(writer.Flush()); string filename1 = writer.FileName(); EXPECT_NE(filename0, filename1); VerifyFile(filename1); diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index db4c5c35e365ca4eed48e07cbae3ad83bcb28622..34db96075d45f690cffad44bcc08cdf17d6e68dc 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -1112,9 +1112,11 @@ inline void ForwardMklTensorInToOutWithMklShape(OpKernelContext* context, // Forward the MKL shape ONLY (used in elementwise and other ops where // we call the eigen implementation and MKL shape is not used) inline void ForwardMklMetaDataInToOut(OpKernelContext* context, - uint idx_data_in, uint idx_data_out) { - uint idx_meta_in = GetTensorMetaDataIndex(idx_data_in, context->num_inputs()); - uint idx_meta_out = + uint32 idx_data_in, + uint32_t idx_data_out) { + uint32 idx_meta_in = + GetTensorMetaDataIndex(idx_data_in, context->num_inputs()); + uint32 idx_meta_out = GetTensorMetaDataIndex(idx_data_out, context->num_outputs()); if (IsRefType(context->input_dtype(idx_data_in))) { @@ -1126,7 +1128,7 @@ inline void ForwardMklMetaDataInToOut(OpKernelContext* context, // Set a dummy MKL shape (called when the output is in TF format) inline void SetDummyMklShapeOutput(OpKernelContext* context, - uint idx_data_out) { + uint32 idx_data_out) { MklShape mkl_shape_output; mkl_shape_output.SetMklTensor(false); AllocateOutputSetMklShape(context, idx_data_out, mkl_shape_output); diff --git a/tensorflow/docs_src/about/roadmap.md b/tensorflow/docs_src/about/roadmap.md index 3ee825ed400de93553bf69fee065fcf8ef13be4d..1f934acab69276d4c32393bb73632d978e0d15c3 100644 --- a/tensorflow/docs_src/about/roadmap.md +++ b/tensorflow/docs_src/about/roadmap.md @@ -1,37 +1,86 @@ # Roadmap -**Last updated: January 23, 2017** +**Last updated: Feb 15, 2018** -TensorFlow is a fast moving project. In order for the community to better -understand what the near future will bring, this document shares what we are -working on internally. Many of these features were requested by the community, -and we welcome -[contributions](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome). +TensorFlow is a rapidly moving, community supported project. This document is intended +to provide guidance about priorities and focus areas of the core set of TensorFlow +developers and about functionality that can be expected in the upcoming releases of +TensorFlow. Many of these areas are driven by community use cases, and we welcome +further +[contributions](https://github.com/tensorflow/tensorflow/blob/master/CONTRIBUTING.md) +to TensorFlow. -The features on this list are targeted for the next few months. At this point, -we do not have timelines for these features. +The features below do not have concrete release dates. However, the majority can be +expected in the next one to two releases. -### Improve non-Python language support +### APIs +#### High Level APIs: +* Easy multi-GPU utilization with Estimators +* Easy-to-use high-level pre-made estimators for Gradient Boosted Trees, Time Series, and other models -* Support for adding gradient computation for graphs constructed in other - languages (C++, Java, Go etc.) +#### Eager Execution: +* Efficient utilization of multiple GPUs +* Distributed training (multi-machine) +* Performance improvements +* Simpler export to a GraphDef/SavedModel -### Making TensorFlow easier to use -* High-level APIs -* Well-maintained models showing best practices +#### Keras API: +* Better integration with tf.data (ability to call `model.fit` with data tensors) +* Full support for Eager Execution (both Eager support for the regular Keras API, and ability +to create Keras models Eager- style via Model subclassing) +* Better distribution/multi-GPU support and TPU support (including a smoother model-to-estimator workflow) -### Performance -* Speed and memory benchmarks -* Distributed full model benchmarks -* Performance and memory usage improvements +#### Official Models: +* A set of +[reference models](https://github.com/tensorflow/models/tree/master/official) +across image recognition, speech, object detection, and + translation that demonstrate best practices and serve as a starting point for + high-performance model development. + +#### Contrib: +* Deprecation notices added to parts of tf.contrib where preferred implementations exist outside of tf.contrib. +* As much as possible, large projects inside tf.contrib moved to separate repositories. +* The tf.contrib module will eventually be discontinued in its current form, experimental development will in future happen in other repositories. -### Core Features -* Automatic op placement ([#2126](https://github.com/tensorflow/tensorflow/issues/2126)) -* Support for graph-level functions + +#### Probabilistic Reasoning and Statistical Analysis: +* Rich set of tools for probabilistic and statistical analysis in tf.distributions + and tf.probability. These include new samplers, layers, optimizers, losses, and structured models +* Statistical tools for hypothesis testing, convergence diagnostics, and sample statistics +* Edward 2.0: High-level API for probabilistic programming ### Platforms -* OpenCL support ([#22](https://github.com/tensorflow/tensorflow/issues/22)) +#### TensorFlow Lite: +* Increased coverage of supported ops in TensorFlow Lite +* Easier conversion of a trained TensorFlow graph for use on TensorFlow Lite +* Support for GPU acceleration in TensorFlow Lite (iOS and Android) +* Support for hardware accelerators via Android NeuralNets API +* Improved CPU performance by quantization and other network optimizations (eg. pruning, distillation) +* Increased support for devices beyond Android and iOS (eg. RPi, Cortex-M) + +### Performance +#### Distributed TensorFlow: +* Multi-GPU support optimized for a variety of GPU topologies +* Improved mechanisms for distributing computations on several machines + +#### Optimizations: +* Mixed precision training support with initial example model and guide +* Native TensorRT support +* Int8 support for SkyLake via MKL +* Dynamic loading of SIMD-optimized kernels + +### Documentation and Usability: +* Updated documentation, tutorials and Getting Started guides +* Process to enable external contributions to tutorials, documentation, and blogs showcasing best practice use-cases of TensorFlow and high-impact applications + +### Community and Partner Engagement +#### Special Interest Groups: +* Mobilizing the community to work together in focused domains +* [tf-distribute](https://groups.google.com/a/tensorflow.org/forum/#!forum/tf-distribute) +: build and packaging of TensorFlow +* More to be identified and launched -### Community -* More educational resources -* Better integration of TensorFlow into the opensource big data ecosystem (e.g. -[#2655](https://github.com/tensorflow/tensorflow/issues/2655)) +#### Community: +* Incorporate public feedback on significant design decisions via a Request-for-Comment (RFC) process +* Formalize process for external contributions to land in TensorFlow and associated projects +* Grow global TensorFlow communities and user groups +* Collaborate with partners to co-develop and publish research papers diff --git a/tensorflow/docs_src/about/uses.md b/tensorflow/docs_src/about/uses.md index 8818177a288ef16ac1907a20ab563ee3d871f7fd..d646880bd350c42e463680a5c7eb0903f2c0a497 100644 --- a/tensorflow/docs_src/about/uses.md +++ b/tensorflow/docs_src/about/uses.md @@ -22,6 +22,14 @@ This section describes some of the current uses of the TensorFlow system. > TensorFlow, or even better, send us a pull request to add an entry to this > file. +* **Deep Speech** +

+ * **RankBrain**
  • **Organization**: Google
  • diff --git a/tensorflow/docs_src/get_started/checkpoints.md b/tensorflow/docs_src/get_started/checkpoints.md index dfa2110e691167f54e6ea8b7a832f0a88d0ec41a..4aa07c7f2a0b56aa6de6f42e30c364c348753a39 100644 --- a/tensorflow/docs_src/get_started/checkpoints.md +++ b/tensorflow/docs_src/get_started/checkpoints.md @@ -154,7 +154,7 @@ classifier = tf.estimator.DNNClassifier( The first time you call an Estimator's `train` method, TensorFlow saves a checkpoint to the `model_dir`. Each subsequent call to the Estimator's -`train`, `eval`, or `predict` method causes the following: +`train`, `evaluate`, or `predict` method causes the following: 1. The Estimator builds the model's [graph](https://developers.google.com/machine-learning/glossary/#graph) @@ -222,7 +222,7 @@ does not match the shape stored in checkpoint: [20] To run experiments in which you train and compare slightly different versions of a model, save a copy of the code that created each -`model-dir`, possibly by creating a separate git branch for each version. +`model_dir`, possibly by creating a separate git branch for each version. This separation will keep your checkpoints recoverable. ## Summary diff --git a/tensorflow/docs_src/get_started/custom_estimators.md b/tensorflow/docs_src/get_started/custom_estimators.md index 42a246678a054d637fea5a82a03ecb84ff412bd9..ae89b639b422f4bd9e36302cbe78c445d497aa10 100644 --- a/tensorflow/docs_src/get_started/custom_estimators.md +++ b/tensorflow/docs_src/get_started/custom_estimators.md @@ -213,7 +213,7 @@ is connected to every node in the preceding layer. Here's the relevant code: ``` * The `units` parameter defines the number of output neurons in a given layer. -* The `activation` parameter defines the [activation function](https://developers.google.com/machine-learning/glossary/#a) — +* The `activation` parameter defines the [activation function](https://developers.google.com/machine-learning/glossary/#activation_function) — [Relu](https://developers.google.com/machine-learning/glossary/#ReLU) in this case. diff --git a/tensorflow/docs_src/get_started/datasets_quickstart.md b/tensorflow/docs_src/get_started/datasets_quickstart.md index a8a2ab6e56130c7805d48477301c63d88f87489c..c972e5e555eea1fab5a67fdecf13264897785519 100644 --- a/tensorflow/docs_src/get_started/datasets_quickstart.md +++ b/tensorflow/docs_src/get_started/datasets_quickstart.md @@ -28,8 +28,8 @@ def train_input_fn(features, labels, batch_size): # Shuffle, repeat, and batch the examples. dataset = dataset.shuffle(1000).repeat().batch(batch_size) - # Build the Iterator, and return the read end of the pipeline. - return dataset.make_one_shot_iterator().get_next() + # Return the dataset. + return dataset ``` Let's look at this more closely. @@ -40,7 +40,7 @@ This function expects three arguments. Arguments expecting an "array" can accept nearly anything that can be converted to an array with `numpy.array`. One exception is [`tuple`](https://docs.python.org/3/tutorial/datastructures.html#tuples-and-sequences) -which has special meaning for `Datasets`. +which, as we will see, has special meaning for `Datasets`. * `features`: A `{'feature_name':array}` dictionary (or [`DataFrame`](https://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.html)) @@ -73,11 +73,12 @@ Let's walk through the `train_input_fn()`. ### Slices -In the simplest cases, @{tf.data.Dataset.from_tensor_slices} function takes an -array and returns a @{tf.data.Dataset} representing slices of the array. For -example, an array containing the @{$tutorials/layers$mnist training data} -has a shape of `(60000, 28, 28)`. Passing this to `from_tensor_slices` returns -a `Dataset` object containing 60000 slices, each one a 28x28 image. +The function starts by using the @{tf.data.Dataset.from_tensor_slices} function +to create a @{tf.data.Dataset} representing slices of the array. The array is +sliced across the first dimension. For example, an array containing the +@{$tutorials/layers$mnist training data} has a shape of `(60000, 28, 28)`. +Passing this to `from_tensor_slices` returns a `Dataset` object containing +60000 slices, each one a 28x28 image. The code that returns this `Dataset` is as follows: @@ -89,18 +90,24 @@ mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x) print(mnist_ds) ``` -This will print the following line, showing the @{$programmers_guide/tensors#shapes$shapes} and @{$programmers_guide/tensors#data_types$types} of the items in -the dataset. Note that the dataset does not know how many items it contains. +This will print the following line, showing the +@{$programmers_guide/tensors#shapes$shapes} and +@{$programmers_guide/tensors#data_types$types} of the items in +the dataset. Note that a `Dataset` does not know how many items it contains. ``` None ``` -The dataset above represents a collection of simple arrays, but datasets are -much more powerful than this. Datasets transparently handle any nested -combination of dictionaries or tuples. For example, ensuring that `features` -is a standard dictionary, you can then convert the dictionary of arrays to -a `Dataset` of dictionaries as follows: +The `Dataset` above represents a simple collection of arrays, but datasets are +much more powerful than this. A `Dataset` can transparently handle any nested +combination of dictionaries or tuples (or +[`namedtuple`](https://docs.python.org/2/library/collections.html#collections.namedtuple) +). + +For example after converting the iris `features` +to a standard python dictionary, you can then convert the dictionary of arrays +to a `Dataset` of dictionaries as follows: ``` python dataset = tf.data.Dataset.from_tensor_slices(dict(features)) @@ -124,9 +131,9 @@ and `types` of the `Dataset` take on the same structure. This dataset contains dictionaries of @{$programmers_guide/tensors#rank$scalars}, all of type `tf.float64`. -The first line of `train_input_fn` uses the same functionality, but adds -another level of structure. It creates a dataset containing -`(features, labels)` pairs. +The first line of the iris `train_input_fn` uses the same functionality, but +adds another level of structure. It creates a dataset containing +`(features_dict, label)` pairs. The following code shows that the label is a scalar with type `int64`: @@ -164,14 +171,14 @@ dataset = dataset.shuffle(1000).repeat().batch(batch_size) ``` The @{tf.data.Dataset.shuffle$`shuffle`} method uses a fixed-size buffer to -shuffle the items as they pass through. Setting a `buffer_size` greater than -the number of examples in the `Dataset` ensures that the data is completely -shuffled. The Iris data set only contains 150 examples. +shuffle the items as they pass through. In this case the `buffer_size` is +greater than the number of examples in the `Dataset`, ensuring that the data is +completely shuffled (The Iris data set only contains 150 examples). -The @{tf.data.Dataset.repeat$`repeat`} method has the `Dataset` restart when +The @{tf.data.Dataset.repeat$`repeat`} method restarts the `Dataset` when it reaches the end. To limit the number of epochs, set the `count` argument. -The @{tf.data.Dataset.repeat$`batch`} method collects a number of examples and +The @{tf.data.Dataset.batch$`batch`} method collects a number of examples and stacks them, to create batches. This adds a dimension to their shape. The new dimension is added as the first dimension. The following code uses the `batch` method on the MNIST `Dataset`, from earlier. This results in a @@ -213,35 +220,16 @@ print(dataset) ### Return - +At this point the `Dataset` contains `(features_dict, labels)` pairs. +This is the format expected by the `train` and `evaluate` methods, so the +`input_fn` returns the dataset. -The `train`, `evaluate`, and `predict` methods of every Estimator require -input functions to return a `(features, label)` pair containing -@{$programmers_guide/tensors$tensorflow tensors}. The `train_input_fn` uses -the following line to convert the Dataset into the expected format: +The `labels` can/should be omitted when using the `predict` method. -```python -# Build the Iterator, and return the read end of the pipeline. -features_result, labels_result = dataset.make_one_shot_iterator().get_next() -``` + -The result is a structure of @{$programmers_guide/tensors$TensorFlow tensors}, -matching the layout of the items in the `Dataset`. -For an introduction to what these objects are and how to work with them, -see @{$programmers_guide/low_level_intro}. - -``` python -print((features_result, labels_result)) -``` - -```None -({ - 'SepalLength': , - 'PetalWidth': , - 'PetalLength': , - 'SepalWidth': }, -Tensor("IteratorGetNext_1:4", shape=(?,), dtype=int64)) -``` ## Reading a CSV File @@ -277,9 +265,6 @@ ds = tf.data.TextLineDataset(train_path).skip(1) ### Build a csv line parser -Ultimately we will need to parse each of the lines in the dataset, to -produce the necessary `(features, label)` pairs. - We will start by building a function to parse a single line. The following `iris_data.parse_line` function accomplishes this task using the diff --git a/tensorflow/docs_src/get_started/feature_columns.md b/tensorflow/docs_src/get_started/feature_columns.md index ad3e1fe3e3a4e3f5278e76bcaa0fc8eee2faf374..d8e4bec86357aabd2065be50d1197122c407c9d7 100644 --- a/tensorflow/docs_src/get_started/feature_columns.md +++ b/tensorflow/docs_src/get_started/feature_columns.md @@ -146,10 +146,10 @@ single input number into a four-element vector. Therefore, the model now can learn _four individual weights_ rather than just one; four weights creates a richer model than one weight. More importantly, bucketizing enables the model to clearly distinguish between different year categories since only one of the -elements is set (1) and the other three elements are cleared (0). When we just -use a single number (a year) as input, the model can only learn a linear -relationship. So, bucketing provides the model with additional flexibility that -the model can use to learn. +elements is set (1) and the other three elements are cleared (0). For example, +when we just use a single number (a year) as input, a linear model can only +learn a linear relationship. So, bucketing provides the model with additional +flexibility that the model can use to learn. The following code demonstrates how to create a bucketized feature: @@ -242,7 +242,7 @@ on an explicit vocabulary list. For example: # the elements in the vocabulary list. vocabulary_feature_column = tf.feature_column.categorical_column_with_vocabulary_list( - key="a feature returned by input_fn()", + key=feature_name_from_input_fn, vocabulary_list=["kitchenware", "electronics", "sports"]) ``` @@ -259,7 +259,7 @@ you place the vocabulary words in a separate file. For example: # the elements in the vocabulary file vocabulary_feature_column = tf.feature_column.categorical_column_with_vocabulary_file( - key="a feature returned by input_fn()", + key=feature_name_from_input_fn, vocabulary_file="product_class.txt", vocabulary_size=3) ``` diff --git a/tensorflow/docs_src/get_started/get_started_for_beginners.md b/tensorflow/docs_src/get_started/get_started_for_beginners.md index 9bca7540a73ea4354096de1b999ab708be26925c..b88483be699630d2275850cbc7c461eeb90f5943 100644 --- a/tensorflow/docs_src/get_started/get_started_for_beginners.md +++ b/tensorflow/docs_src/get_started/get_started_for_beginners.md @@ -36,6 +36,7 @@ the following three: alt="Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor" src="../images/iris_three_species.jpg"> + **From left to right, [*Iris setosa*](https://commons.wikimedia.org/w/index.php?curid=170298) (by [Radomil](https://commons.wikimedia.org/wiki/User:Radomil), CC BY-SA 3.0), @@ -90,11 +91,10 @@ a number. Here's the representation scheme: A **model** is the relationship between features and the label. For the Iris problem, the model defines the relationship -between the sepal and petal measurements and the Iris species. -Some simple models can be described with a few lines of algebra; -more complex machine learning models -contain such a large number of interlacing mathematical functions and -parameters that they become hard to summarize mathematically. +between the sepal and petal measurements and the predicted Iris species. Some +simple models can be described with a few lines of algebra, but complex machine +learning models have a large number of parameters that are difficult to +summarize. Could you determine the relationship between the four features and the Iris species *without* using machine learning? That is, could you use @@ -188,6 +188,7 @@ provides a programming stack consisting of multiple API layers:
    + **The TensorFlow Programming Environment.**

     

    @@ -331,7 +332,7 @@ interpret data is such a rich topic that we devote an entire From a code perspective, you build a list of `feature_column` objects by calling functions from the @{tf.feature_column} module. Each object describes an input to the model. To tell the model to interpret data as a floating-point value, -call @{tf.feature_column.numeric_column). In `premade_estimator.py`, all +call @{tf.feature_column.numeric_column}. In `premade_estimator.py`, all four features should be interpreted as literal floating-point values, so the code to create a feature column looks as follows: @@ -380,6 +381,7 @@ fully connected neural network consisting of three hidden layers:
    + **A neural network with three hidden layers.**

     

    @@ -568,6 +570,7 @@ of 0.5. The following suggests a more effective model: 5.5 2.5 4.0 1.3 1 1 + **A model that is 80% accurate.**

     

    diff --git a/tensorflow/docs_src/get_started/premade_estimators.md b/tensorflow/docs_src/get_started/premade_estimators.md index 4f01f997c33c211e8cff81b6b268bb320aa794df..6bffd2e065548a42eb726df34542ecc7480ad38d 100644 --- a/tensorflow/docs_src/get_started/premade_estimators.md +++ b/tensorflow/docs_src/get_started/premade_estimators.md @@ -98,6 +98,7 @@ classifies Iris flowers into three different species based on the size of their alt="Petal geometry compared for three iris species: Iris setosa, Iris virginica, and Iris versicolor" src="../images/iris_three_species.jpg"> + **From left to right, [*Iris setosa*](https://commons.wikimedia.org/w/index.php?curid=170298) (by [Radomil](https://commons.wikimedia.org/wiki/User:Radomil), CC BY-SA 3.0), diff --git a/tensorflow/docs_src/install/index.md b/tensorflow/docs_src/install/index.md index 3c8488643f071c147dfbc4e0b4b4760b0a817718..4f85383925bbb8a03372b020e448a0e604f3b999 100644 --- a/tensorflow/docs_src/install/index.md +++ b/tensorflow/docs_src/install/index.md @@ -3,7 +3,7 @@ We've built and tested TensorFlow on the following 64-bit laptop/desktop operating systems: - * MacOS X 10.11 (El Capitan) or later. + * macOS 10.12.6 (Sierra) or later. * Ubuntu 16.04 or later * Windows 7 or later. diff --git a/tensorflow/docs_src/install/install_c.md b/tensorflow/docs_src/install/install_c.md index a783205b4a2d24182de6496e0173635990120185..818798555aec3a52bd5feb0c0e67d878a6dc41e4 100644 --- a/tensorflow/docs_src/install/install_c.md +++ b/tensorflow/docs_src/install/install_c.md @@ -15,7 +15,7 @@ instructions might also work on other variants, we have only tested following requirements: * Linux, 64-bit, x86 - * macOS X, Version 10.11 (El Capitan) or higher + * macOS X, Version 10.12.6 (Sierra) or higher ## Installation @@ -38,7 +38,7 @@ enable TensorFlow for C: OS="linux" # Change to "darwin" for macOS TARGET_DIRECTORY="/usr/local" curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.6.0-rc0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-${OS}-x86_64-1.6.0-rc1.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_go.md b/tensorflow/docs_src/install/install_go.md index 5249e04615b506186a12807bb71ec4079db8156c..4c6dfa8dafe2042ea7b80498ca35a359f84ce854 100644 --- a/tensorflow/docs_src/install/install_go.md +++ b/tensorflow/docs_src/install/install_go.md @@ -17,7 +17,7 @@ instructions might also work on other variants, we have only tested following requirements: * Linux, 64-bit, x86 - * macOS X, 10.11 (El Capitan) or higher + * macOS X, 10.12.6 (Sierra) or higher ## Installation @@ -38,7 +38,7 @@ steps to install this library and enable TensorFlow for Go: TF_TYPE="cpu" # Change to "gpu" for GPU support TARGET_DIRECTORY='/usr/local' curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.6.0-rc0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-${TF_TYPE}-$(go env GOOS)-x86_64-1.6.0-rc1.tar.gz" | sudo tar -C $TARGET_DIRECTORY -xz The `tar` command extracts the TensorFlow C library into the `lib` diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md index 0c6c773e62483b2272cf3b80da0932b4b800bb71..527884863ea5104e60569008ea067b407e74d29b 100644 --- a/tensorflow/docs_src/install/install_java.md +++ b/tensorflow/docs_src/install/install_java.md @@ -18,7 +18,7 @@ instructions might also work on other variants, we have only tested following requirements: * Ubuntu 16.04 or higher; 64-bit, x86 - * macOS X 10.11 (El Capitan) or higher + * macOS 10.12.6 (Sierra) or higher * Windows 7 or higher; 64-bit, x86 The installation instructions for Android are in a separate @@ -36,7 +36,7 @@ following to the project's `pom.xml` to use the TensorFlow Java APIs: org.tensorflow tensorflow - 1.6.0-rc0 + 1.6.0-rc1 ``` @@ -65,7 +65,7 @@ As an example, these steps will create a Maven project that uses TensorFlow: org.tensorflow tensorflow - 1.6.0-rc0 + 1.6.0-rc1 @@ -123,12 +123,12 @@ instead: org.tensorflow libtensorflow - 1.6.0-rc0 + 1.6.0-rc1 org.tensorflow libtensorflow_jni_gpu - 1.6.0-rc0 + 1.6.0-rc1 ``` @@ -147,7 +147,7 @@ refer to the simpler instructions above instead. Take the following steps to install TensorFlow for Java on Linux or macOS: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc0.jar), + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc1.jar), which is the TensorFlow Java Archive (JAR). 2. Decide whether you will run TensorFlow for Java on CPU(s) only or with @@ -166,7 +166,7 @@ Take the following steps to install TensorFlow for Java on Linux or macOS: OS=$(uname -s | tr '[:upper:]' '[:lower:]') mkdir -p ./jni curl -L \ - "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.6.0-rc0.tar.gz" | + "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.6.0-rc1.tar.gz" | tar -xz -C ./jni ### Install on Windows @@ -174,10 +174,10 @@ Take the following steps to install TensorFlow for Java on Linux or macOS: Take the following steps to install TensorFlow for Java on Windows: 1. Download - [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc0.jar), + [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.6.0-rc1.jar), which is the TensorFlow Java Archive (JAR). 2. Download the following Java Native Interface (JNI) file appropriate for - [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.6.0-rc0.zip). + [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.6.0-rc1.zip). 3. Extract this .zip file. @@ -225,7 +225,7 @@ must be part of your `classpath`. For example, you can include the downloaded `.jar` in your `classpath` by using the `-cp` compilation flag as follows: -
    javac -cp libtensorflow-1.6.0-rc0.jar HelloTF.java
    +
    javac -cp libtensorflow-1.6.0-rc1.jar HelloTF.java
    ### Running @@ -239,11 +239,11 @@ two files are available to the JVM: For example, the following command line executes the `HelloTF` program on Linux and macOS X: -
    java -cp libtensorflow-1.6.0-rc0.jar:. -Djava.library.path=./jni HelloTF
    +
    java -cp libtensorflow-1.6.0-rc1.jar:. -Djava.library.path=./jni HelloTF
    And the following command line executes the `HelloTF` program on Windows: -
    java -cp libtensorflow-1.6.0-rc0.jar;. -Djava.library.path=jni HelloTF
    +
    java -cp libtensorflow-1.6.0-rc1.jar;. -Djava.library.path=jni HelloTF
    d If the program prints Hello from version, you've successfully installed TensorFlow for Java and are ready to use the API. If the program diff --git a/tensorflow/docs_src/install/install_linux.md b/tensorflow/docs_src/install/install_linux.md index 105b225177315db07b1117c3ece4b77dd2b60cb2..e3e115d9f618265864363810acf96033882ad89d 100644 --- a/tensorflow/docs_src/install/install_linux.md +++ b/tensorflow/docs_src/install/install_linux.md @@ -188,7 +188,7 @@ Take the following steps to install TensorFlow with Virtualenv: Virtualenv environment:
    (tensorflow)$ pip3 install --upgrade \
    -     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
    + https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl If you encounter installation problems, see [Common Installation Problems](#common_installation_problems). @@ -293,7 +293,7 @@ take the following steps:
          $ sudo pip3 install --upgrade \
    -     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
    +     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
          
    If this step fails, see @@ -480,8 +480,7 @@ Take the following steps to install TensorFlow in an Anaconda environment:
          (tensorflow)$ pip install --ignore-installed --upgrade \
    -     https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
    - + https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
    ## Validate your installation @@ -648,14 +647,14 @@ This section documents the relevant values for Linux installations. CPU only:
    -https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp27-none-linux_x86_64.whl
    +https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp27-none-linux_x86_64.whl
     
    GPU support:
    -https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp27-none-linux_x86_64.whl
    +https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp27-none-linux_x86_64.whl
     
    Note that GPU support requires the NVIDIA hardware and software described in @@ -667,14 +666,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
    -https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
    +https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
     
    GPU support:
    -https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp34-cp34m-linux_x86_64.whl
    +https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp34-cp34m-linux_x86_64.whl
     
    Note that GPU support requires the NVIDIA hardware and software described in @@ -686,14 +685,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
    -https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp35-cp35m-linux_x86_64.whl
    +https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp35-cp35m-linux_x86_64.whl
     
    GPU support:
    -https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp35-cp35m-linux_x86_64.whl
    +https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp35-cp35m-linux_x86_64.whl
     
    @@ -705,14 +704,14 @@ Note that GPU support requires the NVIDIA hardware and software described in CPU only:
    -https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc0-cp36-cp36m-linux_x86_64.whl
    +https://storage.googleapis.com/tensorflow/linux/cpu/tensorflow-1.6.0rc1-cp36-cp36m-linux_x86_64.whl
     
    GPU support:
    -https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc0-cp36-cp36m-linux_x86_64.whl
    +https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow_gpu-1.6.0rc1-cp36-cp36m-linux_x86_64.whl
     
    diff --git a/tensorflow/docs_src/install/install_mac.md b/tensorflow/docs_src/install/install_mac.md index a6ea548cfbdb3070c19b5c19ebc903ca76a4656a..623ca6bb7919bf74fa9bcaad3184cdf0bcd9ccff 100644 --- a/tensorflow/docs_src/install/install_mac.md +++ b/tensorflow/docs_src/install/install_mac.md @@ -5,7 +5,11 @@ instructions might also work on other macOS variants, we have only tested (and we only support) these instructions on machines meeting the following requirements: - * macOS X 10.11 (El Capitan) or higher + * macOS 10.12.6 (Sierra) or higher + +Note: There are known, accuracy-affecting numerical issues before macOS 10.12.6 +(Sierra) that are described in +[GitHub#15933](https://github.com/tensorflow/tensorflow/issues/15933#issuecomment-366331383). Note: As of version 1.2, TensorFlow no longer provides GPU support on macOS. @@ -114,8 +118,8 @@ Take the following steps to install TensorFlow with Virtualenv: Python 2.7, the command to install TensorFlow in the active Virtualenv is as follows: -
     $ pip3 install --upgrade \
    -     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl
    +
     $ pip install --upgrade \
    +     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl
    If you encounter installation problems, see [Common Installation Problems](#common-installation-problems). @@ -237,8 +241,8 @@ take the following steps: you are installing TensorFlow for Mac OS and Python 2.7 issue the following command: -
     $ sudo pip3 install --upgrade \
    -     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl 
    +
     $ sudo pip install --upgrade \
    +     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl 
    If the preceding command fails, see [installation problems](#common-installation-problems). @@ -347,7 +351,7 @@ Take the following steps to install TensorFlow in an Anaconda environment: TensorFlow for Python 2.7:
     (targetDirectory)$ pip install --ignore-installed --upgrade \
    -     https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py2-none-any.whl
    + https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl @@ -520,7 +524,7 @@ This section documents the relevant values for Mac OS installations.
    -https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py2-none-any.whl
    +https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py2-none-any.whl
     
    @@ -528,5 +532,5 @@ https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py2-none-a
    -https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc0-py3-none-any.whl
    +https://storage.googleapis.com/tensorflow/mac/cpu/tensorflow-1.6.0rc1-py3-none-any.whl
     
    diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md index 7853ec11f59632537ed1f9ebd3bc8f999dd088c7..acf0af0d9d558d58e625fdd315db859a5bd08121 100644 --- a/tensorflow/docs_src/install/install_sources.md +++ b/tensorflow/docs_src/install/install_sources.md @@ -359,10 +359,10 @@ Invoke `pip install` to install that pip package. The filename of the `.whl` file depends on your platform. For example, the following command will install the pip package -for TensorFlow 1.6.0rc0 on Linux: +for TensorFlow 1.6.0rc1 on Linux:
    -$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.6.0rc0-py2-none-any.whl
    +$ sudo pip install /tmp/tensorflow_pkg/tensorflow-1.6.0rc1-py2-none-any.whl
     
    ## Validate your installation @@ -393,8 +393,7 @@ TensorFlow programs:
    Hello, TensorFlow!
    -If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with -TensorFlow}. +If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}. If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). @@ -460,8 +459,8 @@ Stack Overflow and specify the `tensorflow` tag. **Linux** - - + + @@ -479,7 +478,7 @@ Stack Overflow and specify the `tensorflow` tag. **Mac**
    Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
    tensorflow-1.6.0rc0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.0N/AN/A
    tensorflow_gpu-1.6.0rc0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.079
    tensorflow-1.6.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.0N/AN/A
    tensorflow_gpu-1.6.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.9.079
    tensorflow-1.5.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.0N/AN/A
    tensorflow_gpu-1.5.0GPU2.7, 3.3-3.6GCC 4.8Bazel 0.8.079
    tensorflow-1.4.0CPU2.7, 3.3-3.6GCC 4.8Bazel 0.5.4N/AN/A
    - + @@ -493,8 +492,8 @@ Stack Overflow and specify the `tensorflow` tag. **Windows**
    Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
    tensorflow-1.6.0rc0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
    tensorflow-1.6.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
    tensorflow-1.5.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.8.1N/AN/A
    tensorflow-1.4.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.5.4N/AN/A
    tensorflow-1.3.0CPU2.7, 3.3-3.6Clang from xcodeBazel 0.4.5N/AN/A
    - - + + diff --git a/tensorflow/docs_src/install/install_windows.md b/tensorflow/docs_src/install/install_windows.md index 657d37f6bcb953a2faa7cc93bdbb716a57788db8..f0a30ee39448c09d0125f17cc2eaaaee9ab6c1bb 100644 --- a/tensorflow/docs_src/install/install_windows.md +++ b/tensorflow/docs_src/install/install_windows.md @@ -47,7 +47,7 @@ installed on your system: If you have a different version of one of the preceding packages, please change to the specified versions. In particular, the cuDNN version -must match exactly: TensorFlow will not load if it cannot find `cuDNN64_6.dll`. +must match exactly: TensorFlow will not load if it cannot find `cuDNN64_7.dll`. To use a different version of cuDNN, you must build from source. ## Determine how to install TensorFlow @@ -153,8 +153,7 @@ TensorFlow programs:
    Hello, TensorFlow!
    -If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with -TensorFlow}. +If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Started with TensorFlow}. If the system outputs an error message instead of a greeting, see [Common installation problems](#common_installation_problems). diff --git a/tensorflow/docs_src/mobile/mobile_intro.md b/tensorflow/docs_src/mobile/mobile_intro.md index 17dbf1c3e6ad89768529864ba884274a51b3dfb2..69b63ae7d22ced9fd0299f17d1ae2d614c9a6be7 100644 --- a/tensorflow/docs_src/mobile/mobile_intro.md +++ b/tensorflow/docs_src/mobile/mobile_intro.md @@ -235,7 +235,7 @@ TensorFlow [on Github](https://github.com/tensorflow/models) that you can look through. Lean towards the simplest model you can find, and try to get started as soon as you have even a small amount of labelled data, since you’ll get the best results when you’re able to iterate quickly. The shorter the time it takes to -try training a model and running it in s real application, the better overall +try training a model and running it in its real application, the better overall results you’ll see. It’s common for an algorithm to get great training accuracy numbers but then fail to be useful within a real application because there’s a mismatch between the dataset and real usage. Prototype end-to-end usage as soon diff --git a/tensorflow/docs_src/performance/datasets_performance.md b/tensorflow/docs_src/performance/datasets_performance.md index 4f95e17c3598c23645fad07441c267266e5ef34e..46b43b7673c561679e89fff0ae738b0e751fcff5 100644 --- a/tensorflow/docs_src/performance/datasets_performance.md +++ b/tensorflow/docs_src/performance/datasets_performance.md @@ -92,11 +92,11 @@ transform the data. Without pipelining, the CPU and the GPU/TPU sit idle much of the time: -![without pipelining](https://www.tensorflow.org/images/datasets_without_pipelining.png) +![without pipelining](/images/datasets_without_pipelining.png) With pipelining, idle time diminishes significantly: -![with pipelining](https://www.tensorflow.org/images/datasets_with_pipelining.png) +![with pipelining](/images/datasets_with_pipelining.png) The `tf.data` API provides a software pipelining mechanism through the @{tf.data.Dataset.prefetch} transformation, which can be used to decouple the @@ -139,7 +139,7 @@ multiple CPU cores. To make this possible, the `map` transformation provides the the following diagram illustrates the effect of setting `num_parallel_calls=2` to the `map` transformation: -![parallel map](https://www.tensorflow.org/images/datasets_parallel_map.png) +![parallel map](/images/datasets_parallel_map.png) Choosing the best value for the `num_parallel_calls` argument depends on your hardware, characteristics of your training data (such as its size and shape), @@ -213,7 +213,7 @@ number of datasets to overlap can be specified by the `cycle_length` argument. The following diagram illustrates the effect of supplying `cycle_length=2` to the `parallel_interleave` transformation: -![parallel io](https://www.tensorflow.org/images/datasets_parallel_io.png) +![parallel io](/images/datasets_parallel_io.png) To apply this change to our running example, change: diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 5431572db83a84c034c56656928bdc927e708dc9..b2190c5243adce58a5f1d64786391961fdef2130 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -123,7 +123,7 @@ Normalizes an array across batch and spatial dimensions. | `scale` | `ComputationDataHandle` | 1 dimensional array | : : : (\\(\gamma\\)) : | `offset` | `ComputationDataHandle` | 1 dimensional array | -: : : (\\(\beta\\ ) : +: : : (\\(\beta\\)) : | `epsilon` | `float` | Epsilon value (\\(\epsilon\\)) | | `feature_index` | `int64` | Index to feature dimension | : : : in `operand` : @@ -135,8 +135,8 @@ element in `operand`. The `feature_index` must be a valid index for the feature dimension in `operand`. The algorithm goes as follows for each batch in `operand` \\(x\\) that -contains `m` elements with `w` and `h` as the size of spatial dimensions ( -assuming `operand` is an 4 dimensional array): +contains `m` elements with `w` and `h` as the size of spatial dimensions +(assuming `operand` is an 4 dimensional array): - Calculates batch mean \\(\mu_l\\) for each feature `l` in feature dimension: \\(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\\) @@ -170,7 +170,7 @@ Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast operation from a data shape to a target shape. The dimensions must match, and the conversion is an element-wise one; e.g. `s32` elements become `f32` elements via bitcast routine. Bitcast is implemented as a low-level cast, so machines -with different floating point representations will give different results. +with different floating-point representations will give different results. `BitcastConvertType(operand, new_element_type)` @@ -351,7 +351,7 @@ each other) and contains the arguments in the order that they were specified. : : : concatenated between the `operands`. : With the exception of `dimension` all dimensions must be the same. This is -because XLA does not support "ragged" arrays Also note that rank-0 values +because XLA does not support "ragged" arrays. Also note that rank-0 values cannot be concatenated (as it's impossible to name the dimension along which the concatenation occurs). @@ -440,11 +440,13 @@ area and a computation is performed for each possible position of the window. | `lhs` | `ComputationDataHandle` | rank n+2 array of inputs | | `rhs` | `ComputationDataHandle` | rank n+2 array of kernel | : : : weights : -| `window_strides` | `ArraySlice` | n-d array of kernel strides | -| `padding` | `ArraySlice` | size n array of kernel strides| +| `padding` | `ArraySlice>` : padding : -| `lhs_dilation` | `ArraySlice` | n-d lhs dilation factor array | -| `rhs_dilation` | `ArraySlice` | n-d rhs dilation factor array | +| `lhs_dilation` | `ArraySlice` | size n lhs dilation factor | +: : : array | +| `rhs_dilation` | `ArraySlice` | size n rhs dilation factor +: : : array | Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2 array describing the base area. This is called the input, even though of course @@ -468,7 +470,7 @@ filter/kernel/window. The dimensions are, in this order: window that moves across the base area. The `window_strides` argument specifies the stride of the convolutional window -in the spatial dimensions. For example, if the stride in a the first spatial +in the spatial dimensions. For example, if the stride in the first spatial dimension is 3, then the window can only be placed at coordinates where the first spatial index is divisible by 3. @@ -942,7 +944,7 @@ expand the rank of the lower-rank operand up to the rank of the higher-rank operand. `broadcast_dimensions` maps the dimensions of the lower-rank shape to the dimensions of the higher-rank shape. The unmapped dimensions of the expanded shape are filled with dimensions of size one. Degenerate-dimension broadcasting -then broadcasts the shapes along these degenerate dimension to equalize the +then broadcasts the shapes along these degenerate dimensions to equalize the shapes of both operands. The semantics are described in detail on the @{$broadcasting$broadcasting page}. @@ -1027,6 +1029,213 @@ Arguments | Type | Semantics The function is applied to each element in the `operand` array, resulting in an array with the same shape. It is allowed for `operand` to be a scalar (rank 0). +## Gather + +The XLA gather operation stitches together several slices (each slice at a +potentially different runtime offset) of an input tensor into an output tensor. + +### General Semantics + +See also +[`ComputationBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h). +For a more intuitive description, see the "Informal Description" section below. + + `gather(operand, gather_indices, output_window_dims, elided_window_dims, window_bounds, gather_dims_to_operand_dims)` + +|Arguments | Type | Semantics | +|----------------- | ----------------------- | --------------------------------| +|`operand` | `ComputationDataHandle` | The tensor we’re gathering | +: : : from. : +|`gather_indices` | `ComputationDataHandle` | Tensor containing the starting | +: : : indices of the slices we're : +: : : we're stitching together into : +: : : the output tensor. : +|`index_vector_dim` | `int64` | The dimension in | +: : : `gather_indices` that contains : +: : : the starting indices. : +|`output_window_dims` | `ArraySlice` | The set of dimensions in the | +: : : output shape that are _window : +: : : dimensions_ (defined below). : +: : : Not all window dimensions may : +: : : be present in the output shape. : +|`elided_window_dims` | `ArraySlice` | The set of _window dimensions_ | +: : : that are not present in the output shape. : +: : : `window_bounds[i]` must be `1` for all `i` : +: : : in `elided_window_dims`. : +|`window_bounds` | `ArraySlice` | `window_bounds[i]` is the bounds | +: : : for window dimension `i`. This includes : +: : : both the window dimensions that are : +: : : explicitly part of the output shape (via : +: : : `output_window_dims`) and the window : +: : : dimensions that are elided (via : +: : : `elided_window_dims`). : +|`gather_dims_to_operand_dims` | `ArraySlice` | A dimension map (the | +: : : array is interpreted as mapping `i` to : +: : : `gather_dims_to_operand_dims[i]`) from : +: : : the gather indices in `gather_indices` to : +: : : the operand index space. It has to be : +: : : one-to-one and total. : + +For every index `Out` in the output tensor, we compute two things (more +precisely described later): + + - An index into `gather_indices.rank` - `1` dimensions of `gather_indices`, + which gives us a starting index of a slice, _operand slice_, in the operand + tensor. These `gather_indices.rank` - `1` dimensions are all the dimensions + in `gather_indices` except `index_vector_dim`. + + - A _window index_ that has the same rank as the operand. This index is + composed of the values in `Out` at dimensions `output_window_dims`, embedded + with zeroes according to `elided_window_dims`. + +The _window index_ is the relative index of the element in _operand slice_ that +should be present in the output at index `Out`. + +The output is a tensor of rank `output_window_dims.size` + `gather_indices.rank` +- `1`. Additionally, as a shorthand, we define `output_gather_dims` of type +`ArraySlice` as the set of dimensions in the output shape but not in +`output_window_dims`, in ascending order. E.g. if the output tensor has rank +`5`, `output_window_dims` is {`2`, `4`} then `output_gather_dims` is {`0`, `1`, +`3`} + +If `index_vector_dim` is equal to `gather_indices.rank` we implicitly +consider `gather_indices` to have a trailing `1` dimension (i.e. if +`gather_indices` was of shape `[6,7]` and `index_vector_dim` is `2` then +we implicitly consider the shape of `gather_indices` to be `[6,7,1]`). + +The bounds for the output tensor along dimension `i` is computed as follows: + + 1. If `i` is present in `output_gather_dims` (i.e. is equal to + `output_gather_dims[k]` for some `k`) then we pick the corresponding + dimension bounds out of `gather_indices.shape`, skipping + `index_vector_dim` (i.e. pick `gather_indices.shape.dims`[`k`] if `k` + < `index_vector_dim` and `gather_indices.shape.dims`[`k`+`1`] + otherwise). + 2. If `i` is present in `output_window_dims` (i.e. equal to + `output_window_dims`[`k`] for some `k`) then we pick the corresponding + bound out of `window_bounds` after accounting for `elided_window_dims` + (i.e. we pick `adjusted_window_bounds`[`k`] where `adjusted_window_bounds` + is `window_bounds` with the bounds at indices `elided_window_dims` + removed). + +The operand index `In` corresponding to an output index `Out` is computed as +follows: + + 1. Let `G` = { `Out`[`k`] for `k` in `output_gather_dims` }. Use `G` to slice + out vector `S` such that `S`[`i`] = `gather_indices`[Combine(`G`, `i`)] + where Combine(A, b) inserts b at position `index_vector_dim` into A. + Note that this is well defined even if `G` is empty -- if `G` is empty then + `S` = `gather_indices`. + 2. Create an index, `S``in`, into `operand` using `S` by + scattering `S` using the `gather_dims_to_operand_dims` map + (`S``in` is the starting indices for _operand slice_ mentioned + above). More precisely: + 1. `S``in`[`gather_dims_to_operand_dims`[`k`]] = `S`[`k`] if `k` < + `gather_dims_to_operand_dims.size`. + 2. `S``in`[`_`] = `0` otherwise. + 3. Create an index `W``in` into `operand` by scattering the indices + at the output window dimensions in `Out` according to + the `elided_window_dims` set (`W``in` is the _window index_ + mentioned above). More precisely: + 1. `W``in`[`window_dims_to_operand_dims`(`k`)] = `Out`[`k`] if + `k` < `output_window_dims.size` (`window_dims_to_operand_dims` is + defined below). + 2. `W``in`[`_`] = `0` otherwise. + 4. `In` is `W``in` + `S``in` where + is element-wise + addition. + +`window_dims_to_operand_dims` is the monotonic function with domain [`0`, +`output_window_dims.size`) and range [`0`, `operand.rank`) \ +`elided_window_dims`. So if, e.g., `output_window_dims.size` is `4`, +`operand.rank` is `6` and `elided_window_dims` is {`0`, `2`} then +`window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}. + +### Informal Description and Examples + +`index_vector_dim` is set to `gather_indices.rank` - `1` in all of the +examples that follow. More interesting values for `index_vector_dim` +does not change the operation fundamentally, but makes the visual representation +more cumbersome. + +To get an intuition on how all of the above fits together, let's look at an +example that gathers 5 slices of shape `[8,6]` from a `[16,11]` tensor. The +position of a slice into the `[16,11]` tensor can be represented as an index +vector of shape `S64[2]`, so the set of 5 positions can be represented as a +`S64[5,2]` tensor. + +The behavior of the gather operation can then be depicted as an index +transformation that takes [`G`,`W``0`,`W``1`], an index in +the output shape, and maps it to an element in the input tensor in the following +way: + +
    + +
    + +We first select an (`X`,`Y`) vector from the gather indices tensor using `G`. +The element in the output tensor at index +[`G`,`W``0`,`W``1`] is then the element in the input +tensor at index [`X`+`W``0`,`Y`+`W``1`]. + +`window_bounds` is `[8,6]`, which decides the range of W`0` and +W`1`, and this in turn decides the bounds of the slice. + +This gather operation acts as a batch dynamic slice with `G` as the batch +dimension. + +The gather indices may be multidimensional. For instance, a more general +version of the example above using a "gather indices" tensor of shape `[4,5,2]` +would translate indices like this: + +
    + +
    + +Again, this acts as a batch dynamic slice `G``0` and +`G``1` as the batch dimensions. The window bounds are still `[8,6]`. + +The gather operation in XLA generalizes the informal semantics outlined above in +the following ways: + + 1. We can configure which dimensions in the output shape are the window + dimensions (dimensions containing `W``0`, `W``1` in + the last example). The output gather dimensions (dimensions containing + `G``0`, `G``1` in the last example) are defined to be + the output dimensions that are not window dimensions. + + 2. The number of output window dimensions explicitly present in the output + shape may be smaller than the input rank. These "missing" dimensions, which + are listed explicitly as `elided_window_dims`, must have a window bound of + `1`. Since they have a window bound of `1` the only valid index for them is + `0` and eliding them does not introduce ambiguity. + + 3. The slice extracted from the "Gather Indices" tensor ((`X`, `Y`) in the last + example) may have fewer elements than the input tensor rank, and an explicit + mapping dictates how the index should be expanded to have the same rank as + the input. + +As a final example, we use (2) and (3) to implement `tf.gather_nd`: + +
    + +
    + +`G``0` and `G``1` are used to slice out a starting index +from the gather indices tensor as usual, except the starting index has only one +element, `X`. Similarly, there is only one output window index with the value +`W``0`. However, before being used as indices into the input tensor, +these are expanded in accordance to "Gather Index Mapping" +(`gather_dims_to_operand_dims` in the formal description) and "Window Mapping" +(`window_dims_to_operand_dims` in the formal description) into +[`0`,`W``0`] and [`X`,`0`] respectively, adding up to +[`X`,`W``0`]. In other words, the output index +[`G``0`,`G``1`,`W``0`] maps to the input index +[`GatherIndices`[`G``0`,`G``1`,`0`],`X`] which gives us +the semantics for `tf.gather_nd`. + +`window_bounds` for this case is `[1,11]`. Intuitively this means that every +index `X` in the gather indices tensor picks an entire row and the result is the +concatenation of all these rows. ## GetTupleElement @@ -1081,7 +1290,7 @@ result2 = while (condition, init = result1) { ``` Nested tuple shapes are not supported. For an empty tuple shape, the Infeed -operation is effectively a nop and proceeds without reading any data from the +operation is effectively a no-op and proceeds without reading any data from the Infeed of the device. > Note: We plan to allow multiple Infeed operations without a total order, in @@ -1144,7 +1353,7 @@ dimension. `PaddingConfig` is a repeated field of `PaddingConfigDimension`, which contains three fields for each dimension: `edge_padding_low`, `edge_padding_high`, and -`interior_padding`. `edge_padding_low` and `edge_padding_high` specifies the +`interior_padding`. `edge_padding_low` and `edge_padding_high` specify the amount of padding added at the low-end (next to index 0) and the high-end (next to the highest index) of each dimension respectively. The amount of edge padding can be negative -- the absolute value of negative padding indicates the number @@ -1153,8 +1362,8 @@ the amount of padding added between any two elements in each dimension. Interior padding occurs logically before edge padding, so in the case of negative edge padding elements are removed from the interior-padded operand. This operation is a no-op if the edge padding pairs are all (0, 0) and the interior padding values -are all 0. Figure below shows examples of different `edge_padding` and -`interior_padding` values for a two dimensional array. +are all 0. The figure below shows examples of different `edge_padding` and +`interior_padding` values for a two-dimensional array.
    diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md index 9ede4ab83c1dcdb7370e83dfb9227fbb235d0689..d38fbddfa1cfad305b0549bd4a8ffda371c978b6 100644 --- a/tensorflow/docs_src/programmers_guide/datasets.md +++ b/tensorflow/docs_src/programmers_guide/datasets.md @@ -322,9 +322,39 @@ sess.run(iterator.initializer) next1, (next2, next3) = iterator.get_next() ``` -Note that evaluating *any* of `next1`, `next2`, or `next3` will advance the -iterator for all components. A typical consumer of an iterator will include all -components in a single expression. +Note that `next1`, `next2`, and `next3` are tensors produced by the +same op/node (created by `Iterator.get_next()`). Therefore, evaluating *any* of +these tensors will advance the iterator for all components. A typical consumer +of an iterator will include all components in a single expression. + +### Saving iterator state + +The @{tf.contrib.data.make_saveable_from_iterator} function creates a +`SaveableObject` from an iterator, which can be used to save and +restore the current state of the iterator (and, effectively, the whole input +pipeline). A saveable object thus created can be added to @{tf.train.Saver} +variables list or the `tf.GraphKeys.SAVEABLE_OBJECTS` collection for saving and +restoring in the same manner as a @{tf.Variable}. Refer to +@{$saved_model$Saving and Restoring} for details on how to save and restore +variables. + +```python +# Create saveable object from iterator. +saveable = tf.contrib.data.make_saveable_from_iterator(iterator) + +# Save the iterator state by adding it to the saveable objects collection. +tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable) +saver = tf.train.Saver() + +with tf.Session() as sess: + + if should_checkpoint: + saver.save(path_to_checkpoint) + +# Restore the iterator state. +with tf.Session() as sess: + saver.restore(sess, path_to_checkpoint) +``` ## Reading input data diff --git a/tensorflow/docs_src/programmers_guide/low_level_intro.md b/tensorflow/docs_src/programmers_guide/low_level_intro.md index 8f6d3fbd46d8b76d6033d95fd51c1df45733f5a3..05709ad10a9275953d351e4a62cbf6d7fbffbbe3 100644 --- a/tensorflow/docs_src/programmers_guide/low_level_intro.md +++ b/tensorflow/docs_src/programmers_guide/low_level_intro.md @@ -286,6 +286,23 @@ while True: break ``` +If the `Dataset` depends on stateful operations you may need to +initialize the iterator before using it, as shown below: + +``` python +r = tf.random_normal([10,3]) +dataset = tf.data.Dataset.from_tensor_slices(r) +iterator = dataset.make_initializable_iterator() +next_row = iterator.get_next() + +sess.run(iterator.initializer) +while True: + try: + print(sess.run(next_row)) + except tf.errors.OutOfRangeError: + break +``` + For more details on Datasets and Iterators see: @{$programmers_guide/datasets}. ## Layers @@ -295,7 +312,7 @@ the same input. @{tf.layers$Layers} are the preferred way to add trainable parameters to a graph. Layers package together both the variables and the operations that act -on them, . For example a +on them. For example a [densely-connected layer](https://developers.google.com/machine-learning/glossary/#fully_connected_layer) performs a weighted sum across all inputs for each output and applies an optional @@ -478,7 +495,7 @@ good. Here's what we got; your own output will almost certainly differ: [ 0.10527515]] ``` -### loss +### Loss To optimize a model, you first need to define the loss. We'll use the mean square error, a standard loss for regression problems. @@ -504,7 +521,7 @@ TensorFlow provides [**optimizers**](https://developers.google.com/machine-learning/glossary/#optimizer) implementing standard optimization algorithms. These are implemented as sub-classes of @{tf.train.Optimizer}. They incrementally change each -variable in order to minimizethe loss. The simplest optimization algorithm is +variable in order to minimize the loss. The simplest optimization algorithm is [**gradient descent**](https://developers.google.com/machine-learning/glossary/#gradient_descent), implemented by @{tf.train.GradientDescentOptimizer}. It modifies each variable according to the magnitude of the derivative of loss with respect to diff --git a/tensorflow/docs_src/programmers_guide/saved_model.md b/tensorflow/docs_src/programmers_guide/saved_model.md index f27a658342b8d33407e1c6ed5799a10c2305a74c..c54c278584ec6265f0da1453fc266aeec7cb6f30 100644 --- a/tensorflow/docs_src/programmers_guide/saved_model.md +++ b/tensorflow/docs_src/programmers_guide/saved_model.md @@ -3,6 +3,9 @@ This document explains how to save and restore @{$variables$variables} and models. +Important: TensorFlow model files are code. Be careful with untrusted code. +See [Using TensorFlow Securely](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/SECURITY.md) +for details. ## Saving and restoring variables @@ -694,15 +697,15 @@ executing the computation graph later. For example: $ saved_model_cli show --dir \ /tmp/saved_model_dir --tag_set serve --signature_def serving_default The given SavedModel SignatureDef contains the following input(s): -inputs['x'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x:0 + inputs['x'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x:0 The given SavedModel SignatureDef contains the following output(s): -outputs['y'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y:0 + outputs['y'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 Method name is: tensorflow/serving/predict ``` @@ -714,32 +717,32 @@ $ saved_model_cli show --dir /tmp/saved_model_dir --all MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs: signature_def['classify_x2_to_y3']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x2:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['scores'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y3:0 -Method name is: tensorflow/serving/classify + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x2:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['scores'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y3:0 + Method name is: tensorflow/serving/classify ... signature_def['serving_default']: -The given SavedModel SignatureDef contains the following input(s): -inputs['x'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['y'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y:0 -Method name is: tensorflow/serving/predict + The given SavedModel SignatureDef contains the following input(s): + inputs['x'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['y'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 + Method name is: tensorflow/serving/predict ``` diff --git a/tensorflow/docs_src/programmers_guide/variables.md b/tensorflow/docs_src/programmers_guide/variables.md index 64250738056043e236b5eb236bcbf29375655260..e8cf7711552f4c83ed1e03e0753b580cc7505ddc 100644 --- a/tensorflow/docs_src/programmers_guide/variables.md +++ b/tensorflow/docs_src/programmers_guide/variables.md @@ -62,9 +62,10 @@ them. For this reason TensorFlow provides **collections**, which are named lists of tensors or other objects, such as `tf.Variable` instances. By default every `tf.Variable` gets placed in the following two collections: + * `tf.GraphKeys.GLOBAL_VARIABLES` --- variables that can be shared across -multiple devices, - * `tf.GraphKeys.TRAINABLE_VARIABLES`--- variables for which TensorFlow will + multiple devices, + * `tf.GraphKeys.TRAINABLE_VARIABLES` --- variables for which TensorFlow will calculate gradients. If you don't want a variable to be trainable, add it to the diff --git a/tensorflow/docs_src/programmers_guide/version_compat.md b/tensorflow/docs_src/programmers_guide/version_compat.md index a28f1385c87c7a083ee96977c5ab268c6977e17e..e6613cc69f8aedf344fa25b6564889e34cd9bf53 100644 --- a/tensorflow/docs_src/programmers_guide/version_compat.md +++ b/tensorflow/docs_src/programmers_guide/version_compat.md @@ -60,7 +60,8 @@ patch versions. The public APIs consist of * [`tensor_shape`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/tensor_shape.proto) * [`types`](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/framework/types.proto) -## What is *not* covered {not_covered} + +## What is *not* covered Some API functions are explicitly marked as "experimental" and can change in backward incompatible ways between minor releases. These include: diff --git a/tensorflow/docs_src/tutorials/layers.md b/tensorflow/docs_src/tutorials/layers.md index b898cbe29c2bac9ade341fe3b3566e42e133fc5b..5111b16247e2b5c3410e69dcdf08318a35b18c2f 100644 --- a/tensorflow/docs_src/tutorials/layers.md +++ b/tensorflow/docs_src/tutorials/layers.md @@ -635,7 +635,7 @@ should be logged after every 50 steps of training. ### Train the Model Now we're ready to train our model, which we can do by creating `train_input_fn` -ans calling `train()` on `mnist_classifier`. Add the following to `main()`: +and calling `train()` on `mnist_classifier`. Add the following to `main()`: ```python # Train the model diff --git a/tensorflow/docs_src/tutorials/wide.md b/tensorflow/docs_src/tutorials/wide.md index dba6f54c52ca5bf2569c66ad055329708de3991c..005dc020f94f666da295f4ff0342fae858121012 100644 --- a/tensorflow/docs_src/tutorials/wide.md +++ b/tensorflow/docs_src/tutorials/wide.md @@ -82,7 +82,7 @@ Here's a list of columns available in the Census Income dataset: | hours_per_week | Continuous | Hours worked per week. | | native_country | Categorical | Country of origin of the | : : : individual. : -| income | Categorical | ">50K" or "<=50K", meaning | +| income_bracket | Categorical | ">50K" or "<=50K", meaning | : : : whether the person makes more : : : : than $50,000 annually. : diff --git a/tensorflow/examples/android/res/animator/color_animation.xml b/tensorflow/examples/android/res/animator/color_animation.xml new file mode 100644 index 0000000000000000000000000000000000000000..891d8cc1d4f3e59d0371030fd763c5ad468e7887 --- /dev/null +++ b/tensorflow/examples/android/res/animator/color_animation.xml @@ -0,0 +1,30 @@ + + + + + diff --git a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java b/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java index 184df1bdb42802bfe50b15429f09baeb5600e34f..1cddf3dc5568babb8c08c690fad143299f5ccca5 100644 --- a/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java +++ b/tensorflow/examples/android/src/org/tensorflow/demo/SpeechActivity.java @@ -31,7 +31,8 @@ the RecognizeCommands helper class. package org.tensorflow.demo; -import android.animation.ValueAnimator; +import android.animation.AnimatorInflater; +import android.animation.AnimatorSet; import android.app.Activity; import android.content.pm.PackageManager; import android.media.AudioFormat; @@ -329,17 +330,13 @@ public class SpeechActivity extends Activity { labelIndex = i; } } - final View labelView = (View) labelsListView.getChildAt(labelIndex - 2); - ValueAnimator colorAnimation = - ValueAnimator.ofArgb(0x00b3ccff, 0xffb3ccff, 0x00b3ccff); - colorAnimation.setDuration(750); - colorAnimation.addUpdateListener( - new ValueAnimator.AnimatorUpdateListener() { - @Override - public void onAnimationUpdate(ValueAnimator animator) { - labelView.setBackgroundColor((int) animator.getAnimatedValue()); - } - }); + final View labelView = labelsListView.getChildAt(labelIndex - 2); + + AnimatorSet colorAnimation = + (AnimatorSet) + AnimatorInflater.loadAnimator( + SpeechActivity.this, R.animator.color_animation); + colorAnimation.setTarget(labelView); colorAnimation.start(); } } diff --git a/tensorflow/examples/get_started/regression/imports85.py b/tensorflow/examples/get_started/regression/imports85.py index 6bee556eb887a643b3a81691324736427ecc2707..4fdaceea9afee74550196031fe590c3a2abd20ed 100644 --- a/tensorflow/examples/get_started/regression/imports85.py +++ b/tensorflow/examples/get_started/regression/imports85.py @@ -131,11 +131,12 @@ def dataset(y_name="price", train_fraction=0.7): # booleans but we are dealing with symbolic tensors. return ~in_training_set(line) - base_dataset = (tf.contrib.data - # Get the lines from the file. - .TextLineDataset(path) - # drop lines with question marks. - .filter(has_no_question_marks)) + base_dataset = ( + tf.data + # Get the lines from the file. + .TextLineDataset(path) + # drop lines with question marks. + .filter(has_no_question_marks)) train = (base_dataset # Take only the training-set lines. diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py index 461fb1c5173f66278eb585d30bd8749a58fb6245..307eede5c03780e9244b035f020fc7846290d4d9 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,6 +45,7 @@ VALIDATION_FILE = 'validation.tfrecords' def decode(serialized_example): + """Parses an image and label from the given `serialized_example`.""" features = tf.parse_single_example( serialized_example, # Defaults are not specified since both keys are required. @@ -66,6 +67,7 @@ def decode(serialized_example): def augment(image, label): + """Placeholder for data augmentation.""" # OPTIONAL: Could reshape into a 28x28 image and apply distortions # here. Since we are not applying any distortions in this # example, and the next step expects the image to be flattened @@ -74,9 +76,8 @@ def augment(image, label): def normalize(image, label): - # Convert from [0, 255] -> [-0.5, 0.5] floats. + """Convert `image` from [0, 255] -> [-0.5, 0.5] floats.""" image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 - return image, label @@ -106,18 +107,23 @@ def inputs(train, batch_size, num_epochs): if train else VALIDATION_FILE) with tf.name_scope('input'): - # TFRecordDataset opens a protobuf and reads entries line by line - # could also be [list, of, filenames] + # TFRecordDataset opens a binary file and reads one record at a time. + # `filename` could also be a list of filenames, which will be read in order. dataset = tf.data.TFRecordDataset(filename) - dataset = dataset.repeat(num_epochs) - # map takes a python function and applies it to every sample + # The map transformation takes a function and applies it to every element + # of the dataset. dataset = dataset.map(decode) dataset = dataset.map(augment) dataset = dataset.map(normalize) - #the parameter is the queue size + # The shuffle transformation uses a finite-sized buffer to shuffle elements + # in memory. The parameter is the number of elements in the buffer. For + # completely uniform shuffling, set the parameter to be the same as the + # number of elements in the dataset. dataset = dataset.shuffle(1000 + 3 * batch_size) + + dataset = dataset.repeat(num_epochs) dataset = dataset.batch(batch_size) iterator = dataset.make_one_shot_iterator() @@ -153,7 +159,7 @@ def run_training(): sess.run(init_op) try: step = 0 - while True: #train until OutOfRangeError + while True: # Train until OutOfRangeError start_time = time.time() # Run one step of the model. The return values are diff --git a/tensorflow/examples/image_retraining/retrain.py b/tensorflow/examples/image_retraining/retrain.py index 58c5f87884e5a091300f128403d00fb90bad59fe..99a71206acbd533ec8bc5a9644435eacad564cd4 100644 --- a/tensorflow/examples/image_retraining/retrain.py +++ b/tensorflow/examples/image_retraining/retrain.py @@ -41,7 +41,6 @@ The subfolder names are important, since they define what label is applied to each image, but the filenames themselves don't matter. Once your images are prepared, you can run the training with a command like this: - ```bash bazel build tensorflow/examples/image_retraining:retrain && \ bazel-bin/tensorflow/examples/image_retraining/retrain \ @@ -70,17 +69,22 @@ on resource-limited platforms, you can try the `--architecture` flag with a Mobilenet model. For example: Run floating-point version of mobilenet: + ```bash python tensorflow/examples/image_retraining/retrain.py \ --image_dir ~/flower_photos --architecture mobilenet_1.0_224 ``` -Run quantized version of mobilenet: +Run mobilenet, instrumented for quantization: + ```bash python tensorflow/examples/image_retraining/retrain.py \ - --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quantized + --image_dir ~/flower_photos/ --architecture mobilenet_1.0_224_quant ``` +These instrumented models can be converted to fully quantized mobile models via +TensorFlow Lite. + There are 32 different Mobilenet models to choose from, with a variety of file size and latency options. The first number can be '1.0', '0.75', '0.50', or '0.25' to control the size, and the second controls the input image size, either @@ -96,6 +100,12 @@ Visualize the summaries with this command: tensorboard --logdir /tmp/retrain_logs +To use with Tensorflow Serving: + +```bash +tensorflow_model_server --port=9000 --model_name=inception \ + --model_base_path=/tmp/saved_models/ +``` """ from __future__ import absolute_import from __future__ import division @@ -114,7 +124,6 @@ import numpy as np from six.moves import urllib import tensorflow as tf -from tensorflow.contrib.quantize.python import quant_ops from tensorflow.python.framework import graph_util from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import gfile @@ -128,6 +137,9 @@ FLAGS = None # need to update these to reflect the values in the network you're using. MAX_NUM_IMAGES_PER_CLASS = 2 ** 27 - 1 # ~134M +# The location where variable checkpoints will be stored. +CHECKPOINT_NAME = '/tmp/_retrain_checkpoint' + def create_image_lists(image_dir, testing_percentage, validation_percentage): """Builds a list of training images from the file system. @@ -344,8 +356,8 @@ def maybe_download_and_extract(data_url): filepath, _ = urllib.request.urlretrieve(data_url, filepath, _progress) print() statinfo = os.stat(filepath) - tf.logging.info('Successfully downloaded %s %d bytes.', - filename, statinfo.st_size) + tf.logging.info('Successfully downloaded %s %d bytes.', filename, + statinfo.st_size) print('Extracting file from ', filepath) tarfile.open(filepath, 'r:gz').extractall(dest_directory) else: @@ -738,9 +750,9 @@ def variable_summaries(var): tf.summary.histogram('histogram', var) -def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, - bottleneck_tensor_size, quantize_layer): - """Adds a new softmax and fully-connected layer for training. +def add_final_retrain_ops(class_count, final_tensor_name, bottleneck_tensor, + bottleneck_tensor_size, quantize_layer, is_training): + """Adds a new softmax and fully-connected layer for training and eval. We need to retrain the top layer to identify our new classes, so this function adds the right operations to the graph, along with some variables to hold the @@ -756,7 +768,9 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, bottleneck_tensor: The output of the main CNN graph. bottleneck_tensor_size: How many entries in the bottleneck vector. quantize_layer: Boolean, specifying whether the newly added layer should be - quantized. + instrumented for quantized. + is_training: Boolean, specifying whether the newly add layer is for training + or eval. Returns: The tensors for the training and cross entropy results, and tensors for the @@ -771,50 +785,41 @@ def add_final_training_ops(class_count, final_tensor_name, bottleneck_tensor, ground_truth_input = tf.placeholder( tf.int64, [None], name='GroundTruthInput') - # Organizing the following ops as `final_training_ops` so they're easier - # to see in TensorBoard - layer_name = 'final_training_ops' + # Organizing the following ops so they are easier to see in TensorBoard. + layer_name = 'final_retrain_ops' with tf.name_scope(layer_name): with tf.name_scope('weights'): initial_value = tf.truncated_normal( [bottleneck_tensor_size, class_count], stddev=0.001) layer_weights = tf.Variable(initial_value, name='final_weights') - if quantize_layer: - quantized_layer_weights = quant_ops.MovingAvgQuantize( - layer_weights, is_training=True) - variable_summaries(quantized_layer_weights) - variable_summaries(layer_weights) + with tf.name_scope('biases'): layer_biases = tf.Variable(tf.zeros([class_count]), name='final_biases') - if quantize_layer: - quantized_layer_biases = quant_ops.MovingAvgQuantize( - layer_biases, is_training=True) - variable_summaries(quantized_layer_biases) - variable_summaries(layer_biases) with tf.name_scope('Wx_plus_b'): - if quantize_layer: - logits = tf.matmul(bottleneck_input, - quantized_layer_weights) + quantized_layer_biases - logits = quant_ops.MovingAvgQuantize( - logits, - init_min=-32.0, - init_max=32.0, - is_training=True, - num_bits=8, - narrow_range=False, - ema_decay=0.5) - tf.summary.histogram('pre_activations', logits) - else: - logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases - tf.summary.histogram('pre_activations', logits) + logits = tf.matmul(bottleneck_input, layer_weights) + layer_biases + tf.summary.histogram('pre_activations', logits) final_tensor = tf.nn.softmax(logits, name=final_tensor_name) + # The tf.contrib.quantize functions rewrite the graph in place for + # quantization. The imported model graph has already been rewritten, so upon + # calling these rewrites, only the newly added final layer will be + # transformed. + if quantize_layer: + if is_training: + tf.contrib.quantize.create_training_graph() + else: + tf.contrib.quantize.create_eval_graph() + tf.summary.histogram('activations', final_tensor) + # If this is an eval graph, we don't need to add loss ops or an optimizer. + if not is_training: + return None, None, bottleneck_input, ground_truth_input, final_tensor + with tf.name_scope('cross_entropy'): cross_entropy_mean = tf.losses.sparse_softmax_cross_entropy( labels=ground_truth_input, logits=logits) @@ -850,13 +855,91 @@ def add_evaluation_step(result_tensor, ground_truth_tensor): return evaluation_step, prediction -def save_graph_to_file(sess, graph, graph_file_name): +def run_final_eval(sess, model_info, class_count, image_lists, jpeg_data_tensor, + decoded_image_tensor, resized_image_tensor, + bottleneck_tensor): + """Runs a final evaluation on an eval graph using the test data set. + + Args: + sess: Session for the train graph. + model_info: Model info dictionary from create_model_info() + class_count: Number of classes + image_lists: Dictionary of training images for each label. + jpeg_data_tensor: The layer to feed jpeg image data into. + decoded_image_tensor: The output of decoding and resizing the image. + resized_image_tensor: The input node of the recognition graph. + bottleneck_tensor: The bottleneck output layer of the CNN graph. + """ + (sess, bottleneck_input, ground_truth_input, evaluation_step, + prediction) = build_eval_session(model_info, class_count) + + test_bottlenecks, test_ground_truth, test_filenames = ( + get_random_cached_bottlenecks(sess, image_lists, FLAGS.test_batch_size, + 'testing', FLAGS.bottleneck_dir, + FLAGS.image_dir, jpeg_data_tensor, + decoded_image_tensor, resized_image_tensor, + bottleneck_tensor, FLAGS.architecture)) + test_accuracy, predictions = sess.run( + [evaluation_step, prediction], + feed_dict={ + bottleneck_input: test_bottlenecks, + ground_truth_input: test_ground_truth + }) + tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % + (test_accuracy * 100, len(test_bottlenecks))) + + if FLAGS.print_misclassified_test_images: + tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===') + for i, test_filename in enumerate(test_filenames): + if predictions[i] != test_ground_truth[i]: + tf.logging.info('%70s %s' % (test_filename, + list(image_lists.keys())[predictions[i]])) + + +def build_eval_session(model_info, class_count): + """Builds an restored eval session without train operations for exporting. + + Args: + model_info: Model info dictionary from create_model_info() + class_count: Number of classes + + Returns: + Eval session containing the restored eval graph. + The bottleneck input, ground truth, eval step, and prediction tensors. + """ + # If quantized, we need to create the correct eval graph for exporting. + eval_graph, bottleneck_tensor, _ = create_model_graph(model_info) + + eval_sess = tf.Session(graph=eval_graph) + with eval_graph.as_default(): + # Add the new layer for exporting. + (_, _, bottleneck_input, + ground_truth_input, final_tensor) = add_final_retrain_ops( + class_count, FLAGS.final_tensor_name, bottleneck_tensor, + model_info['bottleneck_tensor_size'], model_info['quantize_layer'], + False) + + # Now we need to restore the values from the training graph to the eval + # graph. + tf.train.Saver().restore(eval_sess, CHECKPOINT_NAME) + + evaluation_step, prediction = add_evaluation_step(final_tensor, + ground_truth_input) + + return (eval_sess, bottleneck_input, ground_truth_input, evaluation_step, + prediction) + + +def save_graph_to_file(graph, graph_file_name, model_info, class_count): + """Saves an graph to file, creating a valid quantized one if necessary.""" + sess, _, _, _, _ = build_eval_session(model_info, class_count) + graph = sess.graph + output_graph_def = graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [FLAGS.final_tensor_name]) with gfile.FastGFile(graph_file_name, 'wb') as f: f.write(output_graph_def.SerializeToString()) - return def prepare_file_system(): @@ -909,11 +992,10 @@ def create_model_info(architecture): return None version_string = parts[1] if (version_string != '1.0' and version_string != '0.75' and - version_string != '0.50' and version_string != '0.25'): + version_string != '0.5' and version_string != '0.25'): tf.logging.error( - """"The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25', - but found '%s' for architecture '%s'""", - version_string, architecture) + """"The Mobilenet version should be '1.0', '0.75', '0.5', or '0.25', + but found '%s' for architecture '%s'""", version_string, architecture) return None size_string = parts[2] if (size_string != '224' and size_string != '192' and @@ -926,35 +1008,26 @@ def create_model_info(architecture): if len(parts) == 3: is_quantized = False else: - if parts[3] != 'quantized': + if parts[3] != 'quant': tf.logging.error( "Couldn't understand architecture suffix '%s' for '%s'", parts[3], architecture) return None is_quantized = True + data_url = 'http://download.tensorflow.org/models/mobilenet_v1_2018_02_22/' + model_name = 'mobilenet_v1_' + version_string + '_' + size_string if is_quantized: - data_url = 'http://download.tensorflow.org/models/mobilenet_v1_' - data_url += version_string + '_' + size_string + '_quantized_frozen.tgz' - bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' - resized_input_tensor_name = 'Placeholder:0' - model_dir_name = ('mobilenet_v1_' + version_string + '_' + size_string + - '_quantized_frozen') - model_base_name = 'quantized_frozen_graph.pb' - - else: - data_url = 'http://download.tensorflow.org/models/mobilenet_v1_' - data_url += version_string + '_' + size_string + '_frozen.tgz' - bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' - resized_input_tensor_name = 'input:0' - model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string - model_base_name = 'frozen_graph.pb' + model_name += '_quant' + data_url += model_name + '.tgz' + bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0' + resized_input_tensor_name = 'input:0' + model_file_name = model_name + '_frozen.pb' bottleneck_tensor_size = 1001 input_width = int(size_string) input_height = int(size_string) input_depth = 3 - model_file_name = os.path.join(model_dir_name, model_base_name) input_mean = 127.5 input_std = 127.5 else: @@ -1004,6 +1077,47 @@ def add_jpeg_decoding(input_width, input_height, input_depth, input_mean, return jpeg_data, mul_image +def export_model(model_info, class_count, saved_model_dir): + """Exports model for serving. + + Args: + model_info: The modelinfo for the current model. + class_count: The number of classes. + saved_model_dir: Directory in which to save exported model and variables. + """ + # The SavedModel should hold the eval graph. + sess, _, _, _, _ = build_eval_session(model_info, class_count) + graph = sess.graph + with graph.as_default(): + input_tensor = model_info['resized_input_tensor_name'] + in_image = sess.graph.get_tensor_by_name(input_tensor) + inputs = {'image': tf.saved_model.utils.build_tensor_info(in_image)} + + out_classes = sess.graph.get_tensor_by_name('final_result:0') + outputs = { + 'prediction': tf.saved_model.utils.build_tensor_info(out_classes) + } + + signature = tf.saved_model.signature_def_utils.build_signature_def( + inputs=inputs, + outputs=outputs, + method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME) + + legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op') + + # Save out the SavedModel. + builder = tf.saved_model.builder.SavedModelBuilder(saved_model_dir) + builder.add_meta_graph_and_variables( + sess, [tf.saved_model.tag_constants.SERVING], + signature_def_map={ + tf.saved_model.signature_constants. + DEFAULT_SERVING_SIGNATURE_DEF_KEY: + signature + }, + legacy_init_op=legacy_init_op) + builder.save() + + def main(_): # Needed to make sure the logging output is visible. # See https://github.com/tensorflow/tensorflow/issues/3047 @@ -1018,11 +1132,6 @@ def main(_): tf.logging.error('Did not recognize architecture flag') return -1 - # Set up the pre-trained graph. - maybe_download_and_extract(model_info['data_url']) - graph, bottleneck_tensor, resized_image_tensor = ( - create_model_graph(model_info)) - # Look at the folder structure, and create lists of all the images. image_lists = create_image_lists(FLAGS.image_dir, FLAGS.testing_percentage, FLAGS.validation_percentage) @@ -1041,6 +1150,19 @@ def main(_): FLAGS.flip_left_right, FLAGS.random_crop, FLAGS.random_scale, FLAGS.random_brightness) + # Set up the pre-trained graph. + maybe_download_and_extract(model_info['data_url']) + graph, bottleneck_tensor, resized_image_tensor = ( + create_model_graph(model_info)) + + # Add the new layer that we'll be training. + with graph.as_default(): + (train_step, cross_entropy, bottleneck_input, + ground_truth_input, final_tensor) = add_final_retrain_ops( + class_count, FLAGS.final_tensor_name, bottleneck_tensor, + model_info['bottleneck_tensor_size'], model_info['quantize_layer'], + True) + with tf.Session(graph=graph) as sess: # Set up the image decoding sub-graph. jpeg_data_tensor, decoded_image_tensor = add_jpeg_decoding( @@ -1064,15 +1186,8 @@ def main(_): decoded_image_tensor, resized_image_tensor, bottleneck_tensor, FLAGS.architecture) - # Add the new layer that we'll be training. - (train_step, cross_entropy, bottleneck_input, ground_truth_input, - final_tensor) = add_final_training_ops( - len(image_lists.keys()), FLAGS.final_tensor_name, bottleneck_tensor, - model_info['bottleneck_tensor_size'], model_info['quantize_layer']) - # Create the operations we need to evaluate the accuracy of our new layer. - evaluation_step, prediction = add_evaluation_step( - final_tensor, ground_truth_input) + evaluation_step, _ = add_evaluation_step(final_tensor, ground_truth_input) # Merge all the summaries and write them out to the summaries_dir merged = tf.summary.merge_all() @@ -1082,6 +1197,10 @@ def main(_): validation_writer = tf.summary.FileWriter( FLAGS.summaries_dir + '/validation') + # Create a train saver that is used to restore values into an eval graph + # when exporting models. + train_saver = tf.train.Saver() + # Set up all our weights to their initial default values. init = tf.global_variables_initializer() sess.run(init) @@ -1122,6 +1241,9 @@ def main(_): (datetime.now(), i, train_accuracy * 100)) tf.logging.info('%s: Step %d: Cross entropy = %f' % (datetime.now(), i, cross_entropy_value)) + # TODO(suharshs): Make this use an eval graph, to avoid quantization + # moving averages being updated by the validation set, though in + # practice this makes a negligable difference. validation_bottlenecks, validation_ground_truth, _ = ( get_random_cached_bottlenecks( sess, image_lists, FLAGS.validation_batch_size, 'validation', @@ -1144,41 +1266,33 @@ def main(_): if (intermediate_frequency > 0 and (i % intermediate_frequency == 0) and i > 0): + # If we want to do an intermediate save, save a checkpoint of the train + # graph, to restore into the eval graph. + train_saver.save(sess, CHECKPOINT_NAME) intermediate_file_name = (FLAGS.intermediate_output_graphs_dir + 'intermediate_' + str(i) + '.pb') tf.logging.info('Save intermediate result to : ' + intermediate_file_name) - save_graph_to_file(sess, graph, intermediate_file_name) + save_graph_to_file(graph, intermediate_file_name, model_info, + class_count) + + # After training is complete, force one last save of the train checkpoint. + train_saver.save(sess, CHECKPOINT_NAME) # We've completed all our training, so run a final test evaluation on # some new images we haven't used before. - test_bottlenecks, test_ground_truth, test_filenames = ( - get_random_cached_bottlenecks( - sess, image_lists, FLAGS.test_batch_size, 'testing', - FLAGS.bottleneck_dir, FLAGS.image_dir, jpeg_data_tensor, - decoded_image_tensor, resized_image_tensor, bottleneck_tensor, - FLAGS.architecture)) - test_accuracy, predictions = sess.run( - [evaluation_step, prediction], - feed_dict={bottleneck_input: test_bottlenecks, - ground_truth_input: test_ground_truth}) - tf.logging.info('Final test accuracy = %.1f%% (N=%d)' % - (test_accuracy * 100, len(test_bottlenecks))) - - if FLAGS.print_misclassified_test_images: - tf.logging.info('=== MISCLASSIFIED TEST IMAGES ===') - for i, test_filename in enumerate(test_filenames): - if predictions[i] != test_ground_truth[i]: - tf.logging.info('%70s %s' % - (test_filename, - list(image_lists.keys())[predictions[i]])) + run_final_eval(sess, model_info, class_count, image_lists, jpeg_data_tensor, + decoded_image_tensor, resized_image_tensor, + bottleneck_tensor) # Write out the trained graph and labels with the weights stored as # constants. - save_graph_to_file(sess, graph, FLAGS.output_graph) + save_graph_to_file(graph, FLAGS.output_graph, model_info, class_count) with gfile.FastGFile(FLAGS.output_labels, 'w') as f: f.write('\n'.join(image_lists.keys()) + '\n') + export_model(model_info, class_count, FLAGS.saved_model_dir) + if __name__ == '__main__': parser = argparse.ArgumentParser() @@ -1358,9 +1472,15 @@ if __name__ == '__main__': form 'mobilenet__[_quantized]'. For example, 'mobilenet_1.0_224' will pick a model that is 17 MB in size and takes 224 pixel input images, while 'mobilenet_0.25_128_quantized' will choose a much - less accurate, but smaller and faster network that's 920 KB on disk and - takes 128x128 images. See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html + smaller and less accurate model, taking 128x128 images, and instrumented + for eventual quantization via TensorFlow Lite. + See https://research.googleblog.com/2017/06/mobilenets-open-source-models-for.html for more information on Mobilenet.\ """) + parser.add_argument( + '--saved_model_dir', + type=str, + default='/tmp/saved_models/1/', + help='Where to save the exported graph.') FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/examples/image_retraining/retrain_test.py b/tensorflow/examples/image_retraining/retrain_test.py index 8b8dd45fd72e3d29bdb7f6291cc53b912adf3644..fb7324c58ac1be60baad840207f31a61ec6182be 100644 --- a/tensorflow/examples/image_retraining/retrain_test.py +++ b/tensorflow/examples/image_retraining/retrain_test.py @@ -67,22 +67,52 @@ class ImageRetrainingTest(test_util.TensorFlowTestCase): self.assertIsNotNone(sess.graph.get_tensor_by_name('DistortResult:0')) @tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01) - def testAddFinalTrainingOps(self, flags_mock): + def testAddFinalRetrainOps(self, flags_mock): with tf.Graph().as_default(): with tf.Session() as sess: bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck') - # Test creating final training op with quantization - retrain.add_final_training_ops(5, 'final', bottleneck, 1024, False) + # Test creating final training op with quantization. + retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, False, + False) self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0')) @tf.test.mock.patch.object(retrain, 'FLAGS', learning_rate=0.01) - def testAddFinalTrainingOpsQuantized(self, flags_mock): - with tf.Graph().as_default(): + def testAddFinalRetrainOpsQuantized(self, flags_mock): + # Ensure that the training and eval graph for quantized models are correctly + # created. + with tf.Graph().as_default() as g: + with tf.Session() as sess: + bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck') + # Test creating final training op with quantization, set is_training to + # true. + retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, True, True) + self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0')) + found_fake_quant = 0 + for op in g.get_operations(): + if op.type == 'FakeQuantWithMinMaxVars': + found_fake_quant += 1 + # Ensure that the inputs of each FakeQuant operations has 2 Assign + # operations in the training graph (Assign[Min,Max]Last, + # Assign[Min,Max]Ema) + self.assertEqual(2, + len([i for i in op.inputs if 'Assign' in i.name])) + self.assertEqual(found_fake_quant, 2) + with tf.Graph().as_default() as g: with tf.Session() as sess: bottleneck = tf.placeholder(tf.float32, [1, 1024], name='bottleneck') - # Test creating final training op with quantization - retrain.add_final_training_ops(5, 'final', bottleneck, 1024, True) + # Test creating final training op with quantization, set is_training to + # false. + retrain.add_final_retrain_ops(5, 'final', bottleneck, 1024, True, False) self.assertIsNotNone(sess.graph.get_tensor_by_name('final:0')) + found_fake_quant = 0 + for op in g.get_operations(): + if op.type == 'FakeQuantWithMinMaxVars': + found_fake_quant += 1 + for i in op.inputs: + # Ensure that no operations are Assign operation since this is the + # evaluation graph. + self.assertTrue('Assign' not in i.name) + self.assertEqual(found_fake_quant, 2) def testAddEvaluationStep(self): with tf.Graph().as_default(): diff --git a/tensorflow/examples/speech_commands/label_wav_dir.py b/tensorflow/examples/speech_commands/label_wav_dir.py new file mode 100644 index 0000000000000000000000000000000000000000..a34db512dda86be138e07a4ffaa1963fe00a5cea --- /dev/null +++ b/tensorflow/examples/speech_commands/label_wav_dir.py @@ -0,0 +1,136 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Runs a trained audio graph against WAVE files and reports the results. + +The model, labels and .wav files specified in the arguments will be loaded, and +then the predictions from running the model against the audio data will be +printed to the console. This is a useful script for sanity checking trained +models, and as an example of how to use an audio model from Python. + +Here's an example of running it: + +python tensorflow/examples/speech_commands/label_wav_dir.py \ +--graph=/tmp/my_frozen_graph.pb \ +--labels=/tmp/speech_commands_train/conv_labels.txt \ +--wav_dir=/tmp/speech_dataset/left + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import glob +import sys + +import tensorflow as tf + +# pylint: disable=unused-import +from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio +# pylint: enable=unused-import + +FLAGS = None + + +def load_graph(filename): + """Unpersists graph from file as default graph.""" + with tf.gfile.FastGFile(filename, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + + +def load_labels(filename): + """Read in labels, one label per line.""" + return [line.rstrip() for line in tf.gfile.GFile(filename)] + + +def run_graph(wav_dir, labels, input_layer_name, output_layer_name, + num_top_predictions): + """Runs the audio data through the graph and prints predictions.""" + with tf.Session() as sess: + # Feed the audio data as input to the graph. + # predictions will contain a two-dimensional array, where one + # dimension represents the input image count, and the other has + # predictions per class + for wav_path in glob.glob(wav_dir + '/*.wav'): + if not wav_path or not tf.gfile.Exists(wav_path): + tf.logging.fatal('Audio file does not exist %s', wav_path) + + with open(wav_path, 'rb') as wav_file: + wav_data = wav_file.read() + + softmax_tensor = sess.graph.get_tensor_by_name(output_layer_name) + predictions, = sess.run(softmax_tensor, {input_layer_name: wav_data}) + + # Sort to show labels in order of confidence + print('\n%s' % (wav_path.split('/')[-1])) + top_k = predictions.argsort()[-num_top_predictions:][::-1] + for node_id in top_k: + human_string = labels[node_id] + score = predictions[node_id] + print('%s (score = %.5f)' % (human_string, score)) + + return 0 + + +def label_wav(wav_dir, labels, graph, input_name, output_name, how_many_labels): + """Loads the model and labels, and runs the inference to print predictions.""" + if not labels or not tf.gfile.Exists(labels): + tf.logging.fatal('Labels file does not exist %s', labels) + + if not graph or not tf.gfile.Exists(graph): + tf.logging.fatal('Graph file does not exist %s', graph) + + labels_list = load_labels(labels) + + # load graph, which is stored in the default session + load_graph(graph) + + run_graph(wav_dir, labels_list, input_name, output_name, how_many_labels) + + +def main(_): + """Entry point for script, converts flags to arguments.""" + label_wav(FLAGS.wav_dir, FLAGS.labels, FLAGS.graph, FLAGS.input_name, + FLAGS.output_name, FLAGS.how_many_labels) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--wav_dir', type=str, default='', help='Audio file to be identified.') + parser.add_argument( + '--graph', type=str, default='', help='Model to use for identification.') + parser.add_argument( + '--labels', type=str, default='', help='Path to file containing labels.') + parser.add_argument( + '--input_name', + type=str, + default='wav_data:0', + help='Name of WAVE data input node in model.') + parser.add_argument( + '--output_name', + type=str, + default='labels_softmax:0', + help='Name of node outputting a prediction in the model.') + parser.add_argument( + '--how_many_labels', + type=int, + default=3, + help='Number of results to show.') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 13f38dfb32a476477d306093bad6b56e1744a640..d9e684a661f2690c9352baec0649fbf42fc79255 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -278,73 +278,110 @@ func FakeQuantWithMinMaxVarsPerChannelGradient(scope *Scope, gradients tf.Output return op.Output(0), op.Output(1), op.Output(2) } -// Partitions `data` into `num_partitions` tensors using indices from `partitions`. -// -// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` -// becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` -// are placed in `outputs[i]` in lexicographic order of `js`, and the first -// dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. -// In detail, -// -// ```python -// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] +// FakeQuantWithMinMaxVarsPerChannelAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannel. +type FakeQuantWithMinMaxVarsPerChannelAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsPerChannelNumBits sets the optional num_bits attribute to value. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsPerChannelNumBits(value int64) FakeQuantWithMinMaxVarsPerChannelAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxVarsPerChannelNarrowRange sets the optional narrow_range attribute to value. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsPerChannelNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`, // -// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) -// ``` +// `[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]` +// to 'outputs' tensor of same shape as `inputs`. // -// `data.shape` must start with `partitions.shape`. +// `[min; max]` define the clamping range for the `inputs` data. +// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` +// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and +// then de-quantized and output as floats in `[min; max]` interval. +// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. // -// For example: +// This operation has a gradient and thus allows for training `min` and `max` +// values. +func FakeQuantWithMinMaxVarsPerChannel(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelAttr) (outputs tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "FakeQuantWithMinMaxVarsPerChannel", + Input: []tf.Input{ + inputs, min, max, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient. +type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr) + +// FakeQuantWithMinMaxVarsGradientNumBits sets the optional num_bits attribute to value. // -// ```python -// # Scalar partitions. -// partitions = 1 -// num_partitions = 2 -// data = [10, 20] -// outputs[0] = [] # Empty with shape [0, 2] -// outputs[1] = [[10, 20]] +// value: The bitwidth of the quantization; between 2 and 8, inclusive. +// If not specified, defaults to 8 +func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVarsGradientAttr { + return func(m optionalAttr) { + m["num_bits"] = value + } +} + +// FakeQuantWithMinMaxVarsGradientNarrowRange sets the optional narrow_range attribute to value. // -// # Vector partitions. -// partitions = [0, 0, 1, 1, 0] -// num_partitions = 2 -// data = [10, 20, 30, 40, 50] -// outputs[0] = [10, 20, 50] -// outputs[1] = [30, 40] -// ``` +// value: Whether to quantize into 2^num_bits - 1 distinct values. +// If not specified, defaults to false +func FakeQuantWithMinMaxVarsGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsGradientAttr { + return func(m optionalAttr) { + m["narrow_range"] = value + } +} + +// Compute gradients for a FakeQuantWithMinMaxVars operation. // -// See `dynamic_stitch` for an example on how to merge partitions back. +// Arguments: +// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation. +// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation. +// min, max: Quantization interval, scalar floats. // -//
    -// -//
    // -// Arguments: // -// partitions: Any shape. Indices in the range `[0, num_partitions)`. -// num_partitions: The number of partitions to output. -func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_partitions int64) (outputs []tf.Output) { +// Returns Backpropagated gradients w.r.t. inputs: +// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter: +// `sum(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter: +// `sum(gradients * (inputs > max))`. +func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{"num_partitions": num_partitions} + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } opspec := tf.OpSpec{ - Type: "DynamicPartition", + Type: "FakeQuantWithMinMaxVarsGradient", Input: []tf.Input{ - data, partitions, + gradients, inputs, min, max, }, Attrs: attrs, } op := scope.AddOperation(opspec) - if scope.Err() != nil { - return - } - var idx int - var err error - if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { - scope.UpdateErr("DynamicPartition", err) - return - } - return outputs + return op.Output(0), op.Output(1), op.Output(2) } // MutableHashTableOfTensorsV2Attr is an optional argument to MutableHashTableOfTensorsV2. @@ -1644,61 +1681,6 @@ func Igammac(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) { return op.Output(0) } -// FakeQuantWithMinMaxVarsGradientAttr is an optional argument to FakeQuantWithMinMaxVarsGradient. -type FakeQuantWithMinMaxVarsGradientAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsGradientNumBits sets the optional num_bits attribute to value. -// -// value: The bitwidth of the quantization; between 2 and 8, inclusive. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsGradientNumBits(value int64) FakeQuantWithMinMaxVarsGradientAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxVarsGradientNarrowRange sets the optional narrow_range attribute to value. -// -// value: Whether to quantize into 2^num_bits - 1 distinct values. -// If not specified, defaults to false -func FakeQuantWithMinMaxVarsGradientNarrowRange(value bool) FakeQuantWithMinMaxVarsGradientAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Compute gradients for a FakeQuantWithMinMaxVars operation. -// -// Arguments: -// gradients: Backpropagated gradients above the FakeQuantWithMinMaxVars operation. -// inputs: Values passed as inputs to the FakeQuantWithMinMaxVars operation. -// min, max: Quantization interval, scalar floats. -// -// -// -// Returns Backpropagated gradients w.r.t. inputs: -// `gradients * (inputs >= min && inputs <= max)`.Backpropagated gradients w.r.t. min parameter: -// `sum(gradients * (inputs < min))`.Backpropagated gradients w.r.t. max parameter: -// `sum(gradients * (inputs > max))`. -func FakeQuantWithMinMaxVarsGradient(scope *Scope, gradients tf.Output, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsGradientAttr) (backprops_wrt_input tf.Output, backprop_wrt_min tf.Output, backprop_wrt_max tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVarsGradient", - Input: []tf.Input{ - gradients, inputs, min, max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0), op.Output(1), op.Output(2) -} - // LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler. type LogUniformCandidateSamplerAttr func(optionalAttr) @@ -2429,26 +2411,6 @@ func ReaderNumWorkUnitsCompletedV2(scope *Scope, reader_handle tf.Output) (units return op.Output(0) } -// Returns x / y element-wise for real types. -// -// If `x` and `y` are reals, this will return the floating-point division. -// -// *NOTE*: `Div` supports broadcasting. More about broadcasting -// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) -func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "RealDiv", - Input: []tf.Input{ - x, y, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // Computes the log of the absolute value of `Gamma(x)` element-wise. func Lgamma(scope *Scope, x tf.Output) (y tf.Output) { if scope.Err() != nil { @@ -4456,34 +4418,6 @@ func MaxPoolGradGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output return op.Output(0) } -// Fast Fourier transform. -// -// Computes the 1-dimensional discrete Fourier transform over the inner-most -// dimension of `input`. -// -// Arguments: -// input: A complex64 tensor. -// -// Returns A complex64 tensor of the same shape as `input`. The inner-most -// dimension of `input` is replaced with its 1D Fourier transform. -// -// @compatibility(numpy) -// Equivalent to np.fft.fft -// @end_compatibility -func FFT(scope *Scope, input tf.Output) (output tf.Output) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "FFT", - Input: []tf.Input{ - input, - }, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // MaxPoolAttr is an optional argument to MaxPool. type MaxPoolAttr func(optionalAttr) @@ -4597,47 +4531,6 @@ func MaxPoolGradWithArgmax(scope *Scope, input tf.Output, grad tf.Output, argmax return op.Output(0) } -// CriticalSectionOpAttr is an optional argument to CriticalSectionOp. -type CriticalSectionOpAttr func(optionalAttr) - -// CriticalSectionOpContainer sets the optional container attribute to value. -// -// value: the container this critical section is placed in. -// If not specified, defaults to "" -func CriticalSectionOpContainer(value string) CriticalSectionOpAttr { - return func(m optionalAttr) { - m["container"] = value - } -} - -// CriticalSectionOpSharedName sets the optional shared_name attribute to value. -// -// value: the name by which this critical section is referred to. -// If not specified, defaults to "" -func CriticalSectionOpSharedName(value string) CriticalSectionOpAttr { - return func(m optionalAttr) { - m["shared_name"] = value - } -} - -// Creates a handle to a CriticalSection resource. -func CriticalSectionOp(scope *Scope, optional ...CriticalSectionOpAttr) (resource tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "CriticalSectionOp", - - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // FakeQuantWithMinMaxArgsGradientAttr is an optional argument to FakeQuantWithMinMaxArgsGradient. type FakeQuantWithMinMaxArgsGradientAttr func(optionalAttr) @@ -5005,6 +4898,78 @@ func DepthwiseConv2dNative(scope *Scope, input tf.Output, filter tf.Output, stri return op.Output(0) } +// MaxPoolGradV2Attr is an optional argument to MaxPoolGradV2. +type MaxPoolGradV2Attr func(optionalAttr) + +// MaxPoolGradV2DataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradV2DataFormat(value string) MaxPoolGradV2Attr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes gradients of the maxpooling function. +// +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients w.r.t. the output of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients w.r.t. the input to `max_pool`. +func MaxPoolGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradV2Attr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPoolGradV2", + Input: []tf.Input{ + orig_input, orig_output, grad, ksize, strides, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Restore a reader to a previously saved state. +// +// Not all Readers support being restored, so this can produce an +// Unimplemented error. +// +// Arguments: +// reader_handle: Handle to a Reader. +// state: Result of a ReaderSerializeState of a Reader with type +// matching reader_handle. +// +// Returns the created operation. +func ReaderRestoreStateV2(scope *Scope, reader_handle tf.Output, state tf.Output) (o *tf.Operation) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "ReaderRestoreStateV2", + Input: []tf.Input{ + reader_handle, state, + }, + } + return scope.AddOperation(opspec) +} + // TensorArrayGatherV3Attr is an optional argument to TensorArrayGatherV3. type TensorArrayGatherV3Attr func(optionalAttr) @@ -5823,78 +5788,6 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra return op.Output(0) } -// MaxPoolGradV2Attr is an optional argument to MaxPoolGradV2. -type MaxPoolGradV2Attr func(optionalAttr) - -// MaxPoolGradV2DataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradV2DataFormat(value string) MaxPoolGradV2Attr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes gradients of the maxpooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients w.r.t. the output of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients w.r.t. the input to `max_pool`. -func MaxPoolGradV2(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize tf.Output, strides tf.Output, padding string, optional ...MaxPoolGradV2Attr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolGradV2", - Input: []tf.Input{ - orig_input, orig_output, grad, ksize, strides, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - -// Restore a reader to a previously saved state. -// -// Not all Readers support being restored, so this can produce an -// Unimplemented error. -// -// Arguments: -// reader_handle: Handle to a Reader. -// state: Result of a ReaderSerializeState of a Reader with type -// matching reader_handle. -// -// Returns the created operation. -func ReaderRestoreStateV2(scope *Scope, reader_handle tf.Output, state tf.Output) (o *tf.Operation) { - if scope.Err() != nil { - return - } - opspec := tf.OpSpec{ - Type: "ReaderRestoreStateV2", - Input: []tf.Input{ - reader_handle, state, - }, - } - return scope.AddOperation(opspec) -} - // ResourceSparseApplyFtrlV2Attr is an optional argument to ResourceSparseApplyFtrlV2. type ResourceSparseApplyFtrlV2Attr func(optionalAttr) @@ -8776,40 +8669,109 @@ func SparseReduceSumKeepDims(value bool) SparseReduceSumAttr { // `tf.reduce_sum()`. In particular, this Op also returns a dense `Tensor` // instead of a sparse one. // -// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless -// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in -// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained -// with length 1. +// Reduces `sp_input` along the dimensions given in `reduction_axes`. Unless +// `keep_dims` is true, the rank of the tensor is reduced by 1 for each entry in +// `reduction_axes`. If `keep_dims` is true, the reduced dimensions are retained +// with length 1. +// +// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor +// with a single element is returned. Additionally, the axes can be negative, +// which are interpreted according to the indexing rules in Python. +// +// Arguments: +// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a +// SparseTensor, possibly not in canonical ordering. +// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. +// input_shape: 1-D. Shape of the input SparseTensor. +// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. +// +// Returns `R-K`-D. The reduced Tensor. +func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "SparseReduceSum", + Input: []tf.Input{ + input_indices, input_values, input_shape, reduction_axes, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + +// Partitions `data` into `num_partitions` tensors using indices from `partitions`. +// +// For each index tuple `js` of size `partitions.ndim`, the slice `data[js, ...]` +// becomes part of `outputs[partitions[js]]`. The slices with `partitions[js] = i` +// are placed in `outputs[i]` in lexicographic order of `js`, and the first +// dimension of `outputs[i]` is the number of entries in `partitions` equal to `i`. +// In detail, +// +// ```python +// outputs[i].shape = [sum(partitions == i)] + data.shape[partitions.ndim:] +// +// outputs[i] = pack([data[js, ...] for js if partitions[js] == i]) +// ``` +// +// `data.shape` must start with `partitions.shape`. +// +// For example: +// +// ```python +// # Scalar partitions. +// partitions = 1 +// num_partitions = 2 +// data = [10, 20] +// outputs[0] = [] # Empty with shape [0, 2] +// outputs[1] = [[10, 20]] +// +// # Vector partitions. +// partitions = [0, 0, 1, 1, 0] +// num_partitions = 2 +// data = [10, 20, 30, 40, 50] +// outputs[0] = [10, 20, 50] +// outputs[1] = [30, 40] +// ``` +// +// See `dynamic_stitch` for an example on how to merge partitions back. // -// If `reduction_axes` has no entries, all dimensions are reduced, and a tensor -// with a single element is returned. Additionally, the axes can be negative, -// which are interpreted according to the indexing rules in Python. +//
    +// +//
    // // Arguments: -// input_indices: 2-D. `N x R` matrix with the indices of non-empty values in a -// SparseTensor, possibly not in canonical ordering. -// input_values: 1-D. `N` non-empty values corresponding to `input_indices`. -// input_shape: 1-D. Shape of the input SparseTensor. -// reduction_axes: 1-D. Length-`K` vector containing the reduction axes. // -// Returns `R-K`-D. The reduced Tensor. -func SparseReduceSum(scope *Scope, input_indices tf.Output, input_values tf.Output, input_shape tf.Output, reduction_axes tf.Output, optional ...SparseReduceSumAttr) (output tf.Output) { +// partitions: Any shape. Indices in the range `[0, num_partitions)`. +// num_partitions: The number of partitions to output. +func DynamicPartition(scope *Scope, data tf.Output, partitions tf.Output, num_partitions int64) (outputs []tf.Output) { if scope.Err() != nil { return } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } + attrs := map[string]interface{}{"num_partitions": num_partitions} opspec := tf.OpSpec{ - Type: "SparseReduceSum", + Type: "DynamicPartition", Input: []tf.Input{ - input_indices, input_values, input_shape, reduction_axes, + data, partitions, }, Attrs: attrs, } op := scope.AddOperation(opspec) - return op.Output(0) + if scope.Err() != nil { + return + } + var idx int + var err error + if outputs, idx, err = makeOutputList(op, idx, "outputs"); err != nil { + scope.UpdateErr("DynamicPartition", err) + return + } + return outputs } // ResourceApplyAdagradAttr is an optional argument to ResourceApplyAdagrad. @@ -9301,6 +9263,34 @@ func SparseSoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, label return op.Output(0), op.Output(1) } +// Fast Fourier transform. +// +// Computes the 1-dimensional discrete Fourier transform over the inner-most +// dimension of `input`. +// +// Arguments: +// input: A complex64 tensor. +// +// Returns A complex64 tensor of the same shape as `input`. The inner-most +// dimension of `input` is replaced with its 1D Fourier transform. +// +// @compatibility(numpy) +// Equivalent to np.fft.fft +// @end_compatibility +func FFT(scope *Scope, input tf.Output) (output tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "FFT", + Input: []tf.Input{ + input, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // ResourceSparseApplyAdagradDAAttr is an optional argument to ResourceSparseApplyAdagradDA. type ResourceSparseApplyAdagradDAAttr func(optionalAttr) @@ -11437,6 +11427,54 @@ func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf. return scope.AddOperation(opspec) } +// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. +type MaxPoolGradGradAttr func(optionalAttr) + +// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. +// +// value: Specify the data format of the input and output data. With the +// default format "NHWC", the data is stored in the order of: +// [batch, in_height, in_width, in_channels]. +// Alternatively, the format could be "NCHW", the data storage order of: +// [batch, in_channels, in_height, in_width]. +// If not specified, defaults to "NHWC" +func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { + return func(m optionalAttr) { + m["data_format"] = value + } +} + +// Computes second-order gradients of the maxpooling function. +// +// Arguments: +// orig_input: The original input tensor. +// orig_output: The original output tensor. +// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. +// ksize: The size of the window for each dimension of the input tensor. +// strides: The stride of the sliding window for each dimension of the +// input tensor. +// padding: The type of padding algorithm to use. +// +// Returns Gradients of gradients w.r.t. the input to `max_pool`. +func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { + if scope.Err() != nil { + return + } + attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} + for _, a := range optional { + a(attrs) + } + opspec := tf.OpSpec{ + Type: "MaxPoolGradGrad", + Input: []tf.Input{ + orig_input, orig_output, grad, + }, + Attrs: attrs, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Returns the truth value of (x >= y) element-wise. // // *NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting @@ -14994,54 +15032,6 @@ func TensorArrayCloseV3(scope *Scope, handle tf.Output) (o *tf.Operation) { return scope.AddOperation(opspec) } -// MaxPoolGradGradAttr is an optional argument to MaxPoolGradGrad. -type MaxPoolGradGradAttr func(optionalAttr) - -// MaxPoolGradGradDataFormat sets the optional data_format attribute to value. -// -// value: Specify the data format of the input and output data. With the -// default format "NHWC", the data is stored in the order of: -// [batch, in_height, in_width, in_channels]. -// Alternatively, the format could be "NCHW", the data storage order of: -// [batch, in_channels, in_height, in_width]. -// If not specified, defaults to "NHWC" -func MaxPoolGradGradDataFormat(value string) MaxPoolGradGradAttr { - return func(m optionalAttr) { - m["data_format"] = value - } -} - -// Computes second-order gradients of the maxpooling function. -// -// Arguments: -// orig_input: The original input tensor. -// orig_output: The original output tensor. -// grad: 4-D. Gradients of gradients w.r.t. the input of `max_pool`. -// ksize: The size of the window for each dimension of the input tensor. -// strides: The stride of the sliding window for each dimension of the -// input tensor. -// padding: The type of padding algorithm to use. -// -// Returns Gradients of gradients w.r.t. the input to `max_pool`. -func MaxPoolGradGrad(scope *Scope, orig_input tf.Output, orig_output tf.Output, grad tf.Output, ksize []int64, strides []int64, padding string, optional ...MaxPoolGradGradAttr) (output tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"ksize": ksize, "strides": strides, "padding": padding} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "MaxPoolGradGrad", - Input: []tf.Input{ - orig_input, orig_output, grad, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // RandomUniformIntAttr is an optional argument to RandomUniformInt. type RandomUniformIntAttr func(optionalAttr) @@ -15312,57 +15302,6 @@ func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional return op.Output(0) } -// FakeQuantWithMinMaxVarsPerChannelAttr is an optional argument to FakeQuantWithMinMaxVarsPerChannel. -type FakeQuantWithMinMaxVarsPerChannelAttr func(optionalAttr) - -// FakeQuantWithMinMaxVarsPerChannelNumBits sets the optional num_bits attribute to value. -// If not specified, defaults to 8 -func FakeQuantWithMinMaxVarsPerChannelNumBits(value int64) FakeQuantWithMinMaxVarsPerChannelAttr { - return func(m optionalAttr) { - m["num_bits"] = value - } -} - -// FakeQuantWithMinMaxVarsPerChannelNarrowRange sets the optional narrow_range attribute to value. -// If not specified, defaults to false -func FakeQuantWithMinMaxVarsPerChannelNarrowRange(value bool) FakeQuantWithMinMaxVarsPerChannelAttr { - return func(m optionalAttr) { - m["narrow_range"] = value - } -} - -// Fake-quantize the 'inputs' tensor of type float and one of the shapes: `[d]`, -// -// `[b, d]` `[b, h, w, d]` via per-channel floats `min` and `max` of shape `[d]` -// to 'outputs' tensor of same shape as `inputs`. -// -// `[min; max]` define the clamping range for the `inputs` data. -// `inputs` values are quantized into the quantization range (`[0; 2^num_bits - 1]` -// when `narrow_range` is false and `[1; 2^num_bits - 1]` when it is true) and -// then de-quantized and output as floats in `[min; max]` interval. -// `num_bits` is the bitwidth of the quantization; between 2 and 8, inclusive. -// -// This operation has a gradient and thus allows for training `min` and `max` -// values. -func FakeQuantWithMinMaxVarsPerChannel(scope *Scope, inputs tf.Output, min tf.Output, max tf.Output, optional ...FakeQuantWithMinMaxVarsPerChannelAttr) (outputs tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{} - for _, a := range optional { - a(attrs) - } - opspec := tf.OpSpec{ - Type: "FakeQuantWithMinMaxVarsPerChannel", - Input: []tf.Input{ - inputs, min, max, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // RandomShuffleAttr is an optional argument to RandomShuffle. type RandomShuffleAttr func(optionalAttr) @@ -17760,23 +17699,6 @@ func SoftplusGrad(scope *Scope, gradients tf.Output, features tf.Output) (backpr return op.Output(0) } -// Creates a dataset that contains the unique elements of `input_dataset`. -func UniqueDataset(scope *Scope, input_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { - if scope.Err() != nil { - return - } - attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes} - opspec := tf.OpSpec{ - Type: "UniqueDataset", - Input: []tf.Input{ - input_dataset, - }, - Attrs: attrs, - } - op := scope.AddOperation(opspec) - return op.Output(0) -} - // SelfAdjointEigV2Attr is an optional argument to SelfAdjointEigV2. type SelfAdjointEigV2Attr func(optionalAttr) @@ -20021,6 +19943,26 @@ func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, return op.Output(0) } +// Returns x / y element-wise for real types. +// +// If `x` and `y` are reals, this will return the floating-point division. +// +// *NOTE*: `Div` supports broadcasting. More about broadcasting +// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) +func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) { + if scope.Err() != nil { + return + } + opspec := tf.OpSpec{ + Type: "RealDiv", + Input: []tf.Input{ + x, y, + }, + } + op := scope.AddOperation(opspec) + return op.Output(0) +} + // Creates a dataset that concatenates `input_dataset` with `another_dataset`. func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) { if scope.Err() != nil { diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml index 99add510696c852b224b40fbafd03620f2606cd3..d35bb4111271c11839a160517dc9695ead5b46e9 100644 --- a/tensorflow/java/maven/libtensorflow/pom.xml +++ b/tensorflow/java/maven/libtensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.6.0-rc0 + 1.6.0-rc1 ../ libtensorflow diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml index 7bb9879f6838c71f2132dd1e331fdb79ccde8527..d9ba1bbbfb91170257f64a56f47c6c980e8a9570 100644 --- a/tensorflow/java/maven/libtensorflow_jni/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.6.0-rc0 + 1.6.0-rc1 ../ libtensorflow_jni diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml index 268e1bae1fe49b7270b37e1a625f3531a42f556b..f6f532c2c10d0a4dad9fc2d7750ea708652000b1 100644 --- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml +++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.6.0-rc0 + 1.6.0-rc1 ../ libtensorflow_jni_gpu diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml index 6a3abcbc1143598a9405fdd9b7ebf83e1f8196d6..0a6b3d23d7d37515cf275e6a46842e32ada4fee1 100644 --- a/tensorflow/java/maven/pom.xml +++ b/tensorflow/java/maven/pom.xml @@ -6,7 +6,7 @@ 4.0.0 org.tensorflow parentpom - 1.6.0-rc0 + 1.6.0-rc1 pom https://www.tensorflow.org diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml index 54a4fd577a0e3242d4b7f89586b3283f11fca856..1d8e8723731f959c8142f0648fc805593d7beac8 100644 --- a/tensorflow/java/maven/proto/pom.xml +++ b/tensorflow/java/maven/proto/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.6.0-rc0 + 1.6.0-rc1 ../ proto diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml index 76e0fecae4a5134625d812379d7c9029f38d0324..5c1b55085c5df1ec473a3f4e0bf750b236cfc264 100644 --- a/tensorflow/java/maven/tensorflow/pom.xml +++ b/tensorflow/java/maven/tensorflow/pom.xml @@ -6,7 +6,7 @@ org.tensorflow parentpom - 1.6.0-rc0 + 1.6.0-rc1 ../ tensorflow diff --git a/tensorflow/java/src/main/java/org/tensorflow/package-info.java b/tensorflow/java/src/main/java/org/tensorflow/package-info.java index dd4859e1b14045e4123e7f15fbaff98e14d0b377..521c5c610c1f775cf9174664f5b786786ce1181d 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/package-info.java +++ b/tensorflow/java/src/main/java/org/tensorflow/package-info.java @@ -35,5 +35,9 @@ limitations under the License. *
  • Graph execution: Using a Session to execute the graphs and find the best label for an * image. * + * + *

    Additional examples can be found in the tensorflow/models + * GitHub repository. */ package org.tensorflow; diff --git a/tensorflow/java/src/main/native/tensor_jni.cc b/tensorflow/java/src/main/native/tensor_jni.cc index 745abec244d1528e918464473e5d3fb19ad5082c..7e3cf4a88aac5acd4721a07c8316d8d124dce001 100644 --- a/tensorflow/java/src/main/native/tensor_jni.cc +++ b/tensorflow/java/src/main/native/tensor_jni.cc @@ -400,7 +400,13 @@ size_t nonScalarTF_STRINGTensorSize(JNIEnv* env, jarray value, int num_dims) { for (jsize i = 0; i < len; ++i) { jarray elem = static_cast( env->GetObjectArrayElement(static_cast(value), i)); + if (elem == nullptr) { + throwException(env, kNullPointerException, + "null entries in provided array"); + return ret; + } ret += nonScalarTF_STRINGTensorSize(env, elem, num_dims - 1); + if (env->ExceptionCheck()) return ret; } return ret; } @@ -421,8 +427,8 @@ void fillNonScalarTF_STRINGTensorData(JNIEnv* env, jarray value, int num_dims, for (jsize i = 0; i < len; ++i) { jarray elem = static_cast( env->GetObjectArrayElement(static_cast(value), i)); - if (TF_GetCode(status) != TF_OK) return; fillNonScalarTF_STRINGTensorData(env, elem, num_dims - 1, writer, status); + if (TF_GetCode(status) != TF_OK) return; } } } // namespace @@ -444,6 +450,7 @@ JNIEXPORT jlong JNICALL Java_org_tensorflow_Tensor_allocateNonScalarBytes( } const size_t encoded_size = nonScalarTF_STRINGTensorSize(env, value, num_dims); + if (env->ExceptionCheck()) return 0; TF_Tensor* t = TF_AllocateTensor(TF_STRING, dims, num_dims, 8 * num_elements + encoded_size); if (t == nullptr) { diff --git a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java index 6538359d11a95eae698cc5aac8430e74ab1ed74c..1bd00a763ddff2f067183f57cfa80fdcbed84fd2 100644 --- a/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java +++ b/tensorflow/java/src/test/java/org/tensorflow/TensorTest.java @@ -432,7 +432,7 @@ public class TensorTest { try (Tensor t = Tensor.create(vector, Integer.class)) { fail("Tensor.create() should fail because it was given an array of boxed values"); } catch (IllegalArgumentException e) { - // The expected exception + // The expected exception } } @@ -536,4 +536,15 @@ public class TensorTest { assertArrayEquals(matrix, cpy.copyTo(new float[2][3])); } } + + @Test + public void gracefullyFailCreationFromNullArrayForStringTensor() { + // Motivated by: https://github.com/tensorflow/tensorflow/issues/17130 + byte[][] array = new byte[1][]; + try { + Tensors.create(array); + } catch (NullPointerException e) { + // expected. + } + } } diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index f563d32388d63808bb483530c73d7aa669abecd0..b0cb48c80c4d1a1f96fe7dc9ade40002e7a6690a 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -765,6 +765,31 @@ py_library( ], ) +py_library( + name = "smart_cond", + srcs = ["framework/smart_cond.py"], + srcs_version = "PY2AND3", + deps = [ + ":control_flow_ops", + ":tensor_util", + ], +) + +py_test( + name = "smart_cond_test", + size = "small", + srcs = ["framework/smart_cond_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":client_testlib", + ":constant_op", + ":framework_ops", + ":math_ops", + ":session", + ":smart_cond", + ], +) + py_library( name = "sparse_tensor", srcs = ["framework/sparse_tensor.py"], @@ -2518,6 +2543,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":array_ops", + ":checkpointable", ":control_flow_ops", ":dtypes", ":framework_ops", @@ -2851,6 +2877,29 @@ py_library( ], ) +py_library( + name = "checkpointable", + srcs = ["training/checkpointable.py"], + srcs_version = "PY2AND3", + deps = [ + ":dtypes", + ":io_ops_gen", + ":ops", + ":util", + "//tensorflow/python/eager:context", + ], +) + +py_test( + name = "checkpointable_test", + srcs = ["training/checkpointable_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":checkpointable", + ":client_testlib", + ], +) + py_test( name = "evaluation_test", size = "small", @@ -4067,6 +4116,7 @@ py_library( ":control_flow_ops", ":framework_for_generated_wrappers", ":platform", + ":smart_cond", ":tensor_util", ":util", ":variable_scope", @@ -4083,8 +4133,6 @@ py_library( "layers/convolutional.py", "layers/core.py", "layers/layers.py", - "layers/maxout.py", - "layers/network.py", "layers/normalization.py", "layers/pooling.py", ], @@ -4137,25 +4185,6 @@ py_test( ], ) -py_test( - name = "layers_network_test", - size = "small", - srcs = ["layers/network_test.py"], - main = "layers/network_test.py", - srcs_version = "PY2AND3", - deps = [ - ":array_ops", - ":client_testlib", - ":framework_for_generated_wrappers", - ":framework_test_lib", - ":layers", - ":layers_base", - ":sparse_ops", - "//tensorflow/python/eager:context", - "//third_party/py/numpy", - ], -) - py_test( name = "layers_core_test", size = "small", @@ -4194,22 +4223,6 @@ py_test( ], ) -py_test( - name = "layers_maxout_test", - size = "small", - srcs = ["layers/maxout_test.py"], - main = "layers/maxout_test.py", - srcs_version = "PY2AND3", - deps = [ - ":client_testlib", - ":framework_for_generated_wrappers", - ":layers", - ":math_ops", - ":nn_ops", - ":random_ops", - ], -) - py_test( name = "layers_utils_test", size = "small", @@ -4605,6 +4618,34 @@ py_test( ], ) +py_library( + name = "graph_placer", + srcs = [ + "grappler/controller.py", + "grappler/graph_placer.py", + "grappler/hierarchical_controller.py", + ], + deps = [ + ":python", + "//third_party/py/numpy", + ], +) + +py_test( + name = "graph_placer_test", + size = "large", + srcs = ["grappler/graph_placer_test.py"], + tags = [ + "grappler", + "no_pip", # graph_placer is not available in pip. + ], + deps = [ + ":client_testlib", + ":graph_placer", + "//tensorflow/python:math_ops", + ], +) + py_test( name = "memory_optimizer_test", size = "medium", diff --git a/tensorflow/python/client/events_writer.i b/tensorflow/python/client/events_writer.i index de030fcb4282912475ed8853bae9d41cde2c085d..c72b76b8fa4a05588841466a836bc189bb64d154 100644 --- a/tensorflow/python/client/events_writer.i +++ b/tensorflow/python/client/events_writer.i @@ -23,6 +23,9 @@ limitations under the License. %nodefaultctor EventsWriter; +%ignore tensorflow::Status::operator=; +%include "tensorflow/core/lib/core/status.h" + %ignoreall %unignore tensorflow; %unignore tensorflow::EventsWriter; diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i index 1fd488e7b6388f7953a279dca8f93ab57a85f63d..f305cd271f98bea697ea8ff15be799d3e80db0bf 100644 --- a/tensorflow/python/client/tf_session.i +++ b/tensorflow/python/client/tf_session.i @@ -719,6 +719,8 @@ def TF_Reset(target, containers=None, config=None): $1 = &types_local; } +%unignore SetRequireShapeInferenceFns; + %include "tensorflow/python/client/tf_session_helper.h" %unignoreall diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py index b71652c980f233ce116ea89544fcb38ad1d816d1..02720a2e985914d3a6774dc6f64d1316890c46bf 100644 --- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py @@ -28,6 +28,7 @@ from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -202,44 +203,45 @@ class FilesystemCacheDatasetTest(test.TestCase): class MemoryCacheDatasetTest(test.TestCase): def testCacheDatasetPassthrough(self): - repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64)) - dataset = dataset_ops.Dataset.range(3).flat_map( - lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count)) + with ops.device("cpu:0"): + repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64)) + dataset = dataset_ops.Dataset.range(3).flat_map( + lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count)) - cached_dataset = dataset.cache().repeat(2) - uncached_dataset = dataset.repeat(2) + cached_dataset = dataset.cache().repeat(2) + uncached_dataset = dataset.repeat(2) - # Needs to be initializable to capture the variable. - cached_iterator = cached_dataset.make_initializable_iterator() - cached_next = cached_iterator.get_next() - uncached_iterator = uncached_dataset.make_initializable_iterator() - uncached_next = uncached_iterator.get_next() + # Needs to be initializable to capture the variable. + cached_iterator = cached_dataset.make_initializable_iterator() + cached_next = cached_iterator.get_next() + uncached_iterator = uncached_dataset.make_initializable_iterator() + uncached_next = uncached_iterator.get_next() - with self.test_session() as sess: + with self.test_session() as sess: - sess.run(repeat_count.initializer) - sess.run(cached_iterator.initializer) - sess.run(uncached_iterator.initializer) + sess.run(repeat_count.initializer) + sess.run(cached_iterator.initializer) + sess.run(uncached_iterator.initializer) - for i in range(3): - for _ in range(10): - self.assertEqual(sess.run(cached_next), i) - self.assertEqual(sess.run(uncached_next), i) + for i in range(3): + for _ in range(10): + self.assertEqual(sess.run(cached_next), i) + self.assertEqual(sess.run(uncached_next), i) - sess.run(repeat_count.assign(0)) + sess.run(repeat_count.assign(0)) - # The uncached iterator should now be empty. - with self.assertRaises(errors.OutOfRangeError): - sess.run(uncached_next) + # The uncached iterator should now be empty. + with self.assertRaises(errors.OutOfRangeError): + sess.run(uncached_next) - # The cached iterator replays from cache. - for i in range(3): - for _ in range(10): - self.assertEqual(sess.run(cached_next), i) + # The cached iterator replays from cache. + for i in range(3): + for _ in range(10): + self.assertEqual(sess.run(cached_next), i) - # The cached iterator should now be empty. - with self.assertRaises(errors.OutOfRangeError): - sess.run(cached_next) + # The cached iterator should now be empty. + with self.assertRaises(errors.OutOfRangeError): + sess.run(cached_next) def testEmptyCacheReading(self): components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py index f129d07b57b96b7869c84467aeb2276c93531ef8..6aabad2f574551cbdc152fe378eb9dc0f5f71995 100644 --- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py +++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py @@ -21,9 +21,12 @@ import threading import numpy as np +from tensorflow.python.client import session from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -302,6 +305,89 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testFromGeneratorStopShort(self): + + def generator(): + yield 0 + yield 1 + yield 2 + + iterator = ( + dataset_ops.Dataset.from_generator( + generator, output_types=dtypes.int64).make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + self.assertAllEqual(0, sess.run(get_next)) + self.assertAllEqual(1, sess.run(get_next)) + + def testFromGeneratorDestructorCalled(self): + # Use an `Event` to signal that the generator has been deleted. + event = threading.Event() + + class GeneratorWrapper(object): + + def __iter__(self): + return self + + def next(self): + return self.__next__() + + def __next__(self): + return 42 + + def __del__(self): + event.set() + + iterator = dataset_ops.Dataset.from_generator( + GeneratorWrapper, + output_types=dtypes.int64).take(2).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + + with session.Session() as sess: + sess.run(init_op) + self.assertAllEqual(42, sess.run(get_next)) + self.assertAllEqual(42, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + # Test that `GeneratorWrapper` object is destroyed when the + # iterator terminates (and the generator iterator is deleted). + self.assertTrue(event.is_set()) + + def testGeneratorDatasetFinalizeFunctionCalled(self): + # NOTE(mrry): This test tests the internal `_GeneratorDataset`, + # which affords more control over what the finalize function can do than + # the `Dataset.from_generator()` wrapper. + + # Use an `Event` to signal that the generator has been deleted. + event = threading.Event() + + def finalize_fn(_): + def finalize_py_func(): + event.set() + return 0 + return script_ops.py_func(finalize_py_func, [], [dtypes.int64], + stateful=True) + + dummy = constant_op.constant(37) + iterator = (dataset_ops._GeneratorDataset(dummy, lambda x: x, + lambda x: x, finalize_fn) + .take(2) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + self.assertAllEqual(37, sess.run(get_next)) + self.assertAllEqual(37, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self.assertTrue(event.is_set()) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py index 28cb50c00208f95e64bb11ae80656383b1f41e1e..7dbf7268d74a2a18af551de64ced03daab264799 100644 --- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py @@ -201,6 +201,20 @@ class InterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testEmptyInput(self): + iterator = ( + dataset_ops.Dataset.from_tensor_slices([]) + .repeat(None) + .interleave(dataset_ops.Dataset.from_tensors, cycle_length=2) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py index 04d1abdb254feea1df6f1b8cfc5a512802107224..0791c614fa88700fdf2d0d673e168fc9784731a5 100644 --- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py @@ -602,6 +602,28 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testParallelMapOutOfRangeError(self): + def raising_py_func(i): + if i == 100: + raise StopIteration() + else: + return i + + iterator = ( + dataset_ops.Dataset.range(105) + .map(lambda x: script_ops.py_func(raising_py_func, [x], dtypes.int64), + num_parallel_calls=2) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + with self.test_session() as sess: + sess.run(init_op) + for i in range(100): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + class MapDatasetBenchmark(test.Benchmark): diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py index ae08032e191487c38d73876374b24e8f6eefbc80..1d27b036eb804aa301b916b7ed0b7884f75e1a0f 100644 --- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py @@ -201,9 +201,7 @@ class SequenceDatasetTest(test.TestCase): with self.test_session() as sess: sess.run(init_op) - with self.assertRaisesRegexp( - errors.OutOfRangeError, - "Attempted to repeat an empty dataset infinitely."): + with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index b665443b7acb9eb266b6fcf36a002cfce54875f1..3fb1f8d5479fc461a8d1f509c5eec2d0ed4a44c9 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -331,10 +331,10 @@ class Dataset(object): generator_state = Dataset._GeneratorState(generator) - def get_iterator_id_map_fn(unused_dummy): + def get_iterator_id_fn(unused_dummy): """Creates a unique `iterator_id` for each pass over the dataset. - The "iterator_id" disambiguates between multiple concurrently + The returned `iterator_id` disambiguates between multiple concurrently existing iterators. Args: @@ -347,7 +347,7 @@ class Dataset(object): return script_ops.py_func( generator_state.get_next_id, [], dtypes.int64, stateful=True) - def generator_map_fn(iterator_id_t): + def generator_next_fn(iterator_id_t): """Generates the next element from iterator with ID `iterator_id_t`. We map this function across an infinite repetition of the @@ -363,11 +363,9 @@ class Dataset(object): def generator_py_func(iterator_id): """A `py_func` that will be called to invoke the iterator.""" - try: - values = next(generator_state.get_iterator(iterator_id)) - except StopIteration: - generator_state.iterator_completed(iterator_id) - raise StopIteration("Iteration finished.") + # `next()` raises `StopIteration` when there are no more + # elements remaining to be generated. + values = next(generator_state.get_iterator(iterator_id)) # Use the same _convert function from the py_func() implementation to # convert the returned values to arrays early, so that we can inspect @@ -408,17 +406,31 @@ class Dataset(object): return nest.pack_sequence_as(output_types, flat_values) + def finalize_fn(iterator_id_t): + """Releases host-side state for the iterator with ID `iterator_id_t`.""" + + def finalize_py_func(iterator_id): + generator_state.iterator_completed(iterator_id) + # We return a dummy value so that the `finalize_fn` has a valid + # signature. + # NOTE(mrry): Explicitly create an array of `np.int64` because implicit + # casting in `py_func()` will create an array of `np.int32` on Windows, + # leading to a runtime error. + return np.array(0, dtype=np.int64) + + return script_ops.py_func( + finalize_py_func, [iterator_id_t], dtypes.int64, stateful=True) + # This function associates each traversal of `generator` with a unique # iterator ID. - def flat_map_fn(iterator_id_t): - # First, generate an infinite dataset containing the iterator ID repeated - # forever. - repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None) - - # The `generator_map_fn` gets the next element from the iterator with the - # relevant ID, and raises StopIteration when that iterator contains no + def flat_map_fn(dummy_arg): + # The `get_iterator_id_fn` gets a unique ID for the current instance of + # of the generator. + # The `generator_next_fn` gets the next element from the iterator with the + # given ID, and raises StopIteration when that iterator contains no # more elements. - return repeated_id.map(generator_map_fn) + return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn, + finalize_fn) # A single-element dataset that, each time it is evaluated, contains a # freshly-generated and unique (for the returned dataset) int64 @@ -426,7 +438,7 @@ class Dataset(object): # is encapsulated in `generator_state`, and captured in # `get_iterator_id_map_fn`. dummy = 0 - id_dataset = Dataset.from_tensors(dummy).map(get_iterator_id_map_fn) + id_dataset = Dataset.from_tensors(dummy) # A dataset that contains all of the elements generated by a # single iterator created from `generator`, identified by the @@ -1033,6 +1045,196 @@ class SparseTensorSliceDataset(Dataset): return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64) +class _GeneratorDataset(Dataset): + """A `Dataset` that generates elements by invoking a function.""" + + def __init__(self, init_args, init_func, next_func, finalize_func): + """Constructs a `_GeneratorDataset`. + + Args: + init_args: A nested structure representing the arguments to `init_func`. + init_func: A TensorFlow function that will be called on `init_args` each + time a C++ iterator over this dataset is constructed. Returns a nested + structure representing the "state" of the dataset. + next_func: A TensorFlow function that will be called on the result of + `init_func` to produce each element, and that raises `OutOfRangeError` + to terminate iteration. + finalize_func: A TensorFlow function that will be called on the result of + `init_func` immediately before a C++ iterator over this dataset is + destroyed. The return value is ignored. + """ + super(_GeneratorDataset, self).__init__() + # These members will be initialized by `tf_init_func`. + self._state_classes = None + self._state_shapes = None + self._state_types = None + + self._init_args = init_args + + init_args_classes = sparse.get_classes(init_args) + init_args_shapes = nest.pack_sequence_as( + init_args, [t.get_shape() for t in nest.flatten(init_args)]) + init_args_types = nest.pack_sequence_as( + init_args, [t.dtype for t in nest.flatten(init_args)]) + + @function.Defun(*nest.flatten( + sparse.as_dense_types(init_args_types, init_args_classes))) + def tf_init_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + dense_shapes = sparse.as_dense_shapes(init_args_shapes, init_args_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(init_args_classes, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, init_args_types, init_args_shapes, init_args_classes) + if _should_unpack_args(nested_args): + ret = init_func(*nested_args) + else: + ret = init_func(nested_args) + + # If `init_func` returns a list of tensors, `nest.flatten()` and + # `ops.convert_to_tensor()` would conspire to attempt to stack + # those tensors into a single tensor, because the customized + # version of `nest.flatten()` does not recurse into lists. Since + # it is more likely that the list arose from returning the + # result of an operation (such as `tf.py_func()`) that returns a + # list of not-necessarily-stackable tensors, we treat the + # returned value is a `tuple` instead. A user wishing to pack + # the return value into a single tensor can use an explicit + # `tf.stack()` before returning. + if isinstance(ret, list): + ret = tuple(ret) + + # Convert any `SparseTensorValue`s to `SparseTensor`s. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor_lib.SparseTensor.from_value(t) + if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret) + ]) + + self._state_classes = sparse.get_classes(ret) + self._state_shapes = nest.pack_sequence_as( + ret, [t.get_shape() for t in nest.flatten(ret)]) + self._state_types = nest.pack_sequence_as( + ret, [t.dtype for t in nest.flatten(ret)]) + + # Serialize any sparse tensors and convert result to tensors. + ret = nest.pack_sequence_as(ret, [ + ops.convert_to_tensor(t) + for t in nest.flatten(sparse.serialize_sparse_tensors(ret)) + ]) + return nest.flatten(ret) + + self._init_func = tf_init_func + self._init_func.add_to_graph(ops.get_default_graph()) + + # These members will be initialized by `tf_next_func`. + self._output_classes = None + self._output_shapes = None + self._output_types = None + + @function.Defun(*nest.flatten( + sparse.as_dense_types(self._state_types, self._state_classes))) + def tf_next_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the input_dataset. + dense_shapes = sparse.as_dense_shapes(self._state_shapes, + self._state_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(self._state_classes, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, self._state_types, self._state_shapes, + self._state_classes) + if _should_unpack_args(nested_args): + ret = next_func(*nested_args) + else: + ret = next_func(nested_args) + + # If `next_func` returns a list of tensors, `nest.flatten()` and + # `ops.convert_to_tensor()` would conspire to attempt to stack + # those tensors into a single tensor, because the customized + # version of `nest.flatten()` does not recurse into lists. Since + # it is more likely that the list arose from returning the + # result of an operation (such as `tf.py_func()`) that returns a + # list of not-necessarily-stackable tensors, we treat the + # returned value is a `tuple` instead. A user wishing to pack + # the return value into a single tensor can use an explicit + # `tf.stack()` before returning. + if isinstance(ret, list): + ret = tuple(ret) + + # Convert any `SparseTensorValue`s to `SparseTensor`s. + ret = nest.pack_sequence_as(ret, [ + sparse_tensor_lib.SparseTensor.from_value(t) + if sparse_tensor_lib.is_sparse(t) else t for t in nest.flatten(ret) + ]) + + self._output_classes = sparse.get_classes(ret) + self._output_shapes = nest.pack_sequence_as( + ret, [t.get_shape() for t in nest.flatten(ret)]) + self._output_types = nest.pack_sequence_as( + ret, [t.dtype for t in nest.flatten(ret)]) + + # Serialize any sparse tensors and convert result to tensors. + ret = nest.pack_sequence_as(ret, [ + ops.convert_to_tensor(t) + for t in nest.flatten(sparse.serialize_sparse_tensors(ret)) + ]) + return nest.flatten(ret) + + self._next_func = tf_next_func + self._next_func.add_to_graph(ops.get_default_graph()) + + @function.Defun(*nest.flatten( + sparse.as_dense_types(self._state_types, self._state_classes))) + def tf_finalize_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the state. + dense_shapes = sparse.as_dense_shapes(self._state_shapes, + self._state_classes) + for arg, shape in zip(args, nest.flatten(dense_shapes)): + arg.set_shape(shape) + + nested_args = nest.pack_sequence_as(self._state_classes, args) + nested_args = sparse.deserialize_sparse_tensors( + nested_args, self._state_types, self._state_shapes, + self._state_classes) + if _should_unpack_args(nested_args): + return finalize_func(*nested_args) + else: + return finalize_func(nested_args) + + self._finalize_func = tf_finalize_func + self._finalize_func.add_to_graph(ops.get_default_graph()) + + def _as_variant_tensor(self): + return gen_dataset_ops.generator_dataset( + nest.flatten(self._init_args) + self._init_func.captured_inputs, + self._next_func.captured_inputs, + self._finalize_func.captured_inputs, + init_func=self._init_func, + next_func=self._next_func, + finalize_func=self._finalize_func, + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + + @property + def output_classes(self): + return self._output_classes + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + class ZipDataset(Dataset): """A `Dataset` that zips its inputs together.""" diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index e573fe01928b77dea55a782e4e86a00873346f07..4756ec74820bace5bea4e1f41ebe214420fe5c3d 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -44,8 +44,9 @@ GET_NEXT_CALL_WARNING_MESSAGE = ( "This often indicates that `Iterator.get_next()` is being called inside " "a training loop, which will cause gradual slowdown and eventual resource " "exhaustion. If this is the case, restructure your code to call " - "`next_element = iterator.get_next() once outside the loop, and use " - "`next_element` inside the loop.") + "`next_element = iterator.get_next()` once outside the loop, and use " + "`next_element` as the input to some computation that is invoked inside " + "the loop.") @tf_export("data.Iterator") @@ -303,7 +304,42 @@ class Iterator(object): dataset._as_variant_tensor(), self._iterator_resource, name=name) # pylint: disable=protected-access def get_next(self, name=None): - """Returns a nested structure of `tf.Tensor`s containing the next element. + """Returns a nested structure of `tf.Tensor`s representing the next element. + + In graph mode, you should typically call this method *once* and use its + result as the input to another computation. A typical loop will then call + @{tf.Session.run} on the result of that computation. The loop will terminate + when the `Iterator.get_next()` operation raises + @{tf.errors.OutOfRangeError}. The following skeleton shows how to use + this method when building a training loop: + + ```python + dataset = ... # A `tf.data.Dataset` object. + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + # Build a TensorFlow graph that does something with each element. + loss = model_function(next_element) + optimizer = ... # A `tf.train.Optimizer` object. + train_op = optimizer.minimize(loss) + + with tf.Session() as sess: + try: + while True: + sess.run(train_op) + except tf.errors.OutOfRangeError: + pass + ``` + + NOTE: It is legitimate to call `Iterator.get_next()` multiple times, e.g. + when you are distributing different elements to multiple devices in a single + step. However, a common pitfall arises when users call `Iterator.get_next()` + in each iteration of their training loop. `Iterator.get_next()` adds ops to + the graph, and executing each op allocates resources (including threads); as + a consequence, invoking it in every iteration of a training loop causes + slowdown and eventual resource exhaustion. To guard against this outcome, we + log a warning when the number of uses crosses a fixed threshold of + suspiciousness. Args: name: (Optional.) A name for the created operation. diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD index f0e90f67772d114142ccc218ed9f42b723a1b556..253588fc3b2986af3ab8c6be5b0b85f178c06336 100644 --- a/tensorflow/python/debug/BUILD +++ b/tensorflow/python/debug/BUILD @@ -957,7 +957,7 @@ cuda_py_test( cuda_py_test( name = "session_debug_grpc_test", - size = "medium", + size = "large", srcs = ["lib/session_debug_grpc_test.py"], additional_deps = [ ":debug_data", diff --git a/tensorflow/python/debug/wrappers/grpc_wrapper.py b/tensorflow/python/debug/wrappers/grpc_wrapper.py index 74d7c2b9e242f947a33c0bdb6508847808d69c0b..fb9494f57636e46e54ef230cf4803dbb6ccad0c7 100644 --- a/tensorflow/python/debug/wrappers/grpc_wrapper.py +++ b/tensorflow/python/debug/wrappers/grpc_wrapper.py @@ -17,6 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import signal +import sys import traceback # Google-internal import(s). @@ -137,6 +139,29 @@ class GrpcDebugWrapperSession(framework.NonInteractiveDebugWrapperSession): if not address.startswith(common.GRPC_URL_PREFIX) else address) +def _signal_handler(unused_signal, unused_frame): + try: + input_func = raw_input + except NameError: + # Python 3 does not have raw_input. + input_func = input + + while True: + response = input_func("\nSIGINT received. Quit program? (Y/n): ").strip() + if response in ("", "Y", "y"): + sys.exit(0) + elif response in ("N", "n"): + break + + +def register_signal_handler(): + try: + signal.signal(signal.SIGINT, _signal_handler) + except ValueError: + # This can happen if we are not in the MainThread. + pass + + class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): """A tfdbg Session wrapper that can be used with TensorBoard Debugger Plugin. @@ -185,6 +210,8 @@ class TensorBoardDebugWrapperSession(GrpcDebugWrapperSession): # sent to the debug servers. self._sent_graph_version = -1 + register_signal_handler() + def run(self, fetches, feed_dict=None, diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py index 0204254ccab109f4844f077df78902872d1156d5..6705cd31e291d2eab7aa8179e9b2b829f8970c18 100644 --- a/tensorflow/python/debug/wrappers/hooks.py +++ b/tensorflow/python/debug/wrappers/hooks.py @@ -345,6 +345,7 @@ class TensorBoardDebugHook(GrpcDebugHook): self._grpc_debug_server_addresses = grpc_debug_server_addresses self._send_traceback_and_source_code = send_traceback_and_source_code self._sent_graph_version = -1 + grpc_wrapper.register_signal_handler() def before_run(self, run_context): if self._send_traceback_and_source_code: diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py index dd71c0927b99a23929d986128e81c1184e9ec31b..14bcc60006228eeaabea241ee18d960174a9dbea 100644 --- a/tensorflow/python/eager/backprop.py +++ b/tensorflow/python/eager/backprop.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import functools import operator import threading @@ -42,6 +43,26 @@ from tensorflow.python.util import nest from tensorflow.python.util import tf_inspect +class _TensorCache(object): + """Simple cache which evicts items based on length in a FIFO manner.""" + + def __init__(self, max_items=256): + self._data = collections.OrderedDict() + self._max_items = max_items if max_items else 256 + + def put(self, key, value): + self._data[key] = value + + if len(self._data) > self._max_items: + self._data.popitem(last=False) + + def get(self, key): + return self._data.get(key, None) + + def flush(self): + self._data = {} + + _op_attr_type_cache = {} @@ -116,112 +137,6 @@ _gradient_functions_lock = threading.Lock() _tracing = False -# TODO(apassos) replace this with a mechanism which can happen at the op -# gradient function registration site, to be less error-prone -# TODO(apassos) add ops other than those in nn_grad and math_grad -_ops_which_dont_need_outputs = set([ - "Identity", - "MatMul", - "Conv2DBackpropInput", - "Conv2DBackpropFilter", - "Conv3D", - "Conv3DBackpropInputV2", - "AvgPool3D", - "AvgPool3DGrad", - "MaxPool3D", - "MaxPool3DGrad", - "MaxPool3DGradGrad", - "BiasAdd", - "BiasAddV1", - "BiasAddGrad", - "Relu6", - "Softplus", - "SoftplusGrad", - "Softsign", - "ReluGrad", - "Conv2D", - "DepthwiseConv2dNative", - "Dilation2D", - "AvgPool", - "AvgPoolGrad", - "BatchNormWithGlobalNormalization", - "L2Loss", - "Sum", - "Prod", - "SegmentSum", - "SegmentMean", - "SparseSegmentSum", - "SparseSegmentMean", - "SparseSegmentSqrtN", - "SegmentMin", - "SegmentMax", - "UnsortedSegmentSum", - "UnsortedSegmentMax", - "UnsortedSegmentMin", - "UnsortedSegmentProd", - "Abs", - "Neg", - "ReciprocalGrad", - "Square", - "Expm1", - "Log", - "Log1p", - "TanhGrad", - "SigmoidGrad", - "Sign", - "Sin", - "Cos", - "Tan", - "Add", - "Sub", - "Mul", - "Div", - "RealDiv", - "Maximum", - "Minimum", - "SquaredDifference", - "Select", - "SparseMatMul", - "BatchMatMul", - "Complex", - "Real", - "Imag", - "Angle", - "Conj", - "Cast", - "Cross", - "Cumsum", - "Cumprod", - "ReadVariableOp", - "VarHandleOp", - "Shape", -]) - -_ops_which_dont_need_inputs = set([ - "Identity", - "Softmax", - "LogSoftmax", - "BiasAdd", - "Relu", - "Elu", - "Selu", - "SparseSoftmaxCrossEntropyWithLogits", - "Neg", - "Inv", - "Reciprocal", - "Sqrt", - "Exp", - "Tanh", - "Sigmoid", - "Real", - "Imag", - "Conj", - "ReadVariableOp", - "VarHandleOp", - "Shape", -]) - - # TODO(agarwal): use an automatic mechanism for handling None arguments to # gradient functions. # Some gradient functions can accept None arguments for gradients. The following @@ -240,57 +155,25 @@ _grad_fn_accepts_none_for_indices = { } -def _record_gradient(op_name, inputs, attrs, results, name): - """Records gradients for a TensorFlow operation. - - Args: - op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to - execute. - inputs: A flat list of Tensor object inputs to the operation. - attrs: A tuple with alternating string attr names and attr values for this - operation. - results: The results of the operation (as a flat list). - name: Customized name for the operation. - - Returns: - A list of maybe-wrapped results. Either Tensors or TensorNodes. - - Raises: - An exception on error. - """ - if not tape.could_possibly_record(): - return - - if op_name in _ops_which_dont_need_outputs: - op_outputs = None - else: - # TODO(apassos) this line creates a weak circular reference where the - # backprop function keeps an output alive which in turn keeps the tape entry - # alive which keeps the backprop function alive. Figure out how to break - # this up without breaking second derivatives of ops like Exp whose - # gradients depend only on the outputs. - op_outputs = results - - if op_name in _ops_which_dont_need_inputs: - op_inputs = None - else: - op_inputs = inputs - - num_inputs = len(inputs) +def _get_backward_fn(op_name, attrs, num_inputs, op_inputs, op_outputs): def grad_fn(*orig_outputs): - """Generated gradient function.""" result = _magic_gradient_function(op_name, attrs, num_inputs, op_inputs, op_outputs, orig_outputs) if _tracing: - print("Gradient for", (name if name else op_name), "inputs", op_inputs, - "output_grads", orig_outputs, "gradients", result) + print("Gradient for", op_name, "inputs", op_inputs, "output_grads", + orig_outputs, "gradients", result) return nest.flatten(result) - tape.record_operation(op_name, results, inputs, grad_fn) - if _tracing: - print("Computed op", (name if name else op_name), "inputs", inputs, - "outputs", results) + return grad_fn + + +pywrap_tensorflow.TFE_Py_RegisterBackwardFunctionGetter(_get_backward_fn) + + +def _record_gradient(op_name, inputs, attrs, results, name): + return pywrap_tensorflow.TFE_Py_RecordGradient(op_name, inputs, attrs, + results, name) execute.record_gradient = _record_gradient @@ -357,6 +240,7 @@ def implicit_val_and_grad(f): tape.pop_tape(this_tape) # Sorting variables by id, which is monotonically increasing in construction # order. This ensures unique order across executions. + # TODO(josh11b): Move the sort to the C++ implementation in pywrap_tfe_src.cc. variables = list(sorted(this_tape.watched_variables(), key=lambda v: v.handle._id)) # pylint: disable=protected-access sources = [x.handle for x in variables] @@ -618,7 +502,7 @@ def val_and_grad_function(f, params=None): return decorated -def make_vjp(f, params=None): +def make_vjp(f, params=None, persistent=True): """Returns a function that computes f and is vjp w.r.t. params. The term "vjp" here is an abbreviation for vector-jacobian product. @@ -627,6 +511,8 @@ def make_vjp(f, params=None): f: the function to be differentiated. params: the parameters (numbers or names) to differentiate with respect to. A value of None will differentiate with respect to all parameters. + persistent: Boolean controlling whether the VJP function can be re-used. + Must be True or False. Returns: A function, which when called, returns a tuple (value, vjp), where: @@ -654,7 +540,7 @@ def make_vjp(f, params=None): """Computes the value and gradient of the decorated function.""" parameter_positions = _get_arg_spec(f, params, args) assert not kwds, "The gradient function can't take keyword arguments." - this_tape = tape.push_new_tape() + this_tape = tape.push_new_tape(persistent=persistent) try: sources = [] args = [ @@ -736,8 +622,7 @@ def _num_elements(grad): raise ValueError("`grad` not a Tensor or IndexedSlices.") -_last_zero_shape_dtype = [None, None] -_last_zero = [None] +_zeros_cache = _TensorCache() def _fast_fill(value, shape, dtype): @@ -746,14 +631,17 @@ def _fast_fill(value, shape, dtype): def _zeros(shape, dtype): """Wraps array_ops.zeros to cache last zero for a given shape and dtype.""" + device = context.context().device_name if dtype == dtypes.variant: # TODO(apassos): need to save enough information about variant tensors to do # a zeros return None - if [shape, dtype] != _last_zero_shape_dtype: - _last_zero_shape_dtype[:] = [shape, dtype] - _last_zero[0] = _fast_fill(0, shape, dtype) - return _last_zero[0] + cache_key = shape, dtype, device + cached = _zeros_cache.get(cache_key) + if cached is None: + cached = _fast_fill(0, shape, dtype) + _zeros_cache.put(cache_key, cached) + return cached def _ones(shape, dtype): @@ -861,7 +749,11 @@ class GradientTape(object): tape.watch(t) def watched_variables(self): - return self._tape.watched_variables() + # Sorting variables by id, which is monotonically increasing in construction + # order. This ensures unique order across executions. + # TODO(josh11b): Move the sort to the C++ implementation in pywrap_tfe_src.cc. + return list(sorted(self._tape.watched_variables(), + key=lambda v: v.handle._id)) # pylint: disable=protected-access def gradient(self, target, sources, output_gradients=None): """Computes the gradient using information traced by the tape. diff --git a/tensorflow/python/eager/backprop_test.py b/tensorflow/python/eager/backprop_test.py index a12113893ab3eac671e8138472bc95e9d8b89499..48fd1707643511413f501e8b09ba3d86fcd8e904 100644 --- a/tensorflow/python/eager/backprop_test.py +++ b/tensorflow/python/eager/backprop_test.py @@ -115,6 +115,19 @@ class BackpropTest(test.TestCase): with self.assertRaises(RuntimeError): backprop.gradients_function(f)(constant_op.constant(1.0)) + def testGradientsFunctionInCustomGradient(self): + + @custom_gradient.custom_gradient + def f(x): + (y,) = backprop.gradients_function(lambda x: x * x)(x) + + def grad(dy): + return [2 * dy] + + return y, grad + + self.assertAllEqual(f(1.0), 2.0) + def testImplicitGradOverEmbeddingLookup(self): batch_size = 8 embedding_size = 512 @@ -205,11 +218,22 @@ class BackpropTest(test.TestCase): def f(x): return x * x - wrapped_fn = backprop.make_vjp(f) + wrapped_fn = backprop.make_vjp(f, persistent=False) result, vjp = wrapped_fn(constant_op.constant(3.0)) self.assertAllEqual(result, 9.0) self.assertAllEqual(vjp(2.0)[0], 12.0) + def testPersistentMakeVJP(self): + + def f(x): + return x * x + + wrapped_fn = backprop.make_vjp(f, persistent=True) + _, vjp = wrapped_fn(constant_op.constant(3.0)) + vjp_result1 = vjp(2.0)[0] + vjp_result2 = vjp(2.0)[0] + self.assertAllEqual(vjp_result1, vjp_result2, 12.0) + @test_util.assert_no_new_tensors def testGradGrad(self): diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py index ee3c10633e1cb849e319f2f5490e5beb5dd15c80..0e40d8a5c0a582ab27d95735dd917e2a5daabe09 100644 --- a/tensorflow/python/eager/core_test.py +++ b/tensorflow/python/eager/core_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn_ops @@ -99,6 +100,18 @@ class TFETest(test_util.TensorFlowTestCase): self.assertEqual(len(cpu_stats.node_stats), 1) self.assertEqual(cpu_stats.node_stats[0].node_name, 'Add') + def testShouldCopy(self): + if not context.context().num_gpus(): + self.skipTest('No devices other than CPUs found') + with ops.device('gpu:0'): + x = constant_op.constant(1.0) + y = array_ops.identity(x) + # The value we're testing y.device against will depend on what the behavior + # of not explicitly specifying a device in the context is. This behavior is + # subject to change (for example, in the future we may want to use GPUs, if + # available, when no device is explicitly provided) + self.assertEqual(y.device, '/job:localhost/replica:0/task:0/device:CPU:0') + def testContextStackContainsEagerMode(self): # Eager execution has been enabled, and no other context # switch has occurred, so `context_stack` should contain diff --git a/tensorflow/python/eager/custom_gradient.py b/tensorflow/python/eager/custom_gradient.py index 05460ff9968312528d87f5fc2ad0495b4da2ad1a..fb932a937206a9500996e6d1ae721a8294c676d0 100644 --- a/tensorflow/python/eager/custom_gradient.py +++ b/tensorflow/python/eager/custom_gradient.py @@ -71,11 +71,10 @@ def custom_gradient(f): input_tensors = [tf_ops.convert_to_tensor(x) for x in args] - with tape.stop_recording(): - result, grad_fn = f(*args, **kwargs) - flat_result = nest.flatten(result) - # TODO(apassos) consider removing the identity below. - flat_result = [gen_array_ops.identity(x) for x in flat_result] + result, grad_fn = f(*args, **kwargs) + flat_result = nest.flatten(result) + # TODO(apassos) consider removing the identity below. + flat_result = [gen_array_ops.identity(x) for x in flat_result] def actual_grad_fn(*outputs): return nest.flatten(grad_fn(*outputs)) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 767d719ea69440e80853b41bf3ec992f286a635a..b3317bd3235f432220d9d5d135f1af18a6f43310 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -36,6 +36,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_module from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.util import compat from tensorflow.python.util import nest @@ -195,33 +196,66 @@ ops.register_tensor_conversion_function( ops.EagerTensor, _convert_to_graph_tensor, priority=-1) -class _CapturingContext(object): - """Tracks references to Tensors outside this context while it is active.""" +# pylint: disable=invalid-name +class HelperContext(object): + """ControlFlowContext with a customizable AddOp method.""" - def __init__(self): - # known_ops are ops which are created while this context is active - self.known_ops = set() + def __init__(self, add_op_internal): + self._add_op_internal = add_op_internal + self._values = set() # control flow code sometimes updates this. + + def _AddOpInternal(self, op): + self._add_op_internal(op) + + @property + def outer_context(self): + return self._outer_context + + def GetWhileContext(self): + if self._outer_context: + return self._outer_context.GetWhileContext() + + def IsWhileContext(self): + return False + + def IsCondContext(self): + return False - # captured_tensors are all tensors referenced to by ops in this context but - # not produced in it - self.captured_tensors = set() + def IsXLAContext(self): + return False def AddOp(self, op): # pylint: disable=invalid-name - if op.type in ["Variable", "VariableV2", "VarHandleOp"]: - raise ValueError("tfe.defun cannot capture variables created without " - "using tf.get_variable. Op: %s" % op) - self.known_ops.add(op) - for i in op.inputs: - if i.op not in self.known_ops: - self.captured_tensors.add(i) + self._AddOpInternal(op) + if self._outer_context: + self._outer_context.AddOp(op) + + def AddName(self, _): + pass + + def AddInnerOp(self, op): + self._AddOpInternal(op) + if self._outer_context: + self._outer_context.AddInnerOp(op) + + def AddValue(self, val): + if self._outer_context: + return self._outer_context.AddValue(val) + else: + return val def __enter__(self): + # pylint: disable=protected-access self._g = ops.get_default_graph() - self._old = self._g._get_control_flow_context() # pylint: disable=protected-access - self._g._set_control_flow_context(self) # pylint: disable=protected-access + self._outer_context = self._g._get_control_flow_context() + self._g._set_control_flow_context(self) + self._nested_contexts = ( + self._outer_context._nested_contexts + if self._outer_context is not None else None) + # pylint: enable=protected-access - def __exit__(self, _, __, ___): # pylint: disable=invalid-name - self._g._set_control_flow_context(self._old) # pylint: disable=protected-access + def __exit__(self, *_): + self._g._set_control_flow_context(self._outer_context) # pylint: disable=protected-access +# pylint: enable=invalid-name def _forward_name(n): @@ -367,7 +401,20 @@ class GraphModeFunction(object): def _construct_backprop_function(self): """Constructs the backprop function object for this function.""" with self._graph.as_default(), context.graph_mode(): - c = _CapturingContext() + c_known_ops = set() + c_captured_tensors = set() + + def add_op_internal(op): + if op.type in ["Variable", "VariableV2", "VarHandleOp"]: + raise ValueError("tfe.defun cannot capture variables created without " + "using tf.get_variable. Op: %s" % op) + c_known_ops.add(op) + for i in op.inputs: + if i.op not in c_known_ops: + c_captured_tensors.add(i) + + c = HelperContext(add_op_internal) + with c: filtered_outputs = [x for x in self._returns if x is not None] self._out_grad_placeholders = [ @@ -381,7 +428,7 @@ class GraphModeFunction(object): grad for grad in _flatten(in_gradients) if grad is not None) output_shapes = tuple(grad.shape for grad in backward_outputs) - captures = list(sorted(c.captured_tensors, key=lambda x: x.name)) + captures = list(sorted(c_captured_tensors, key=lambda x: x.name)) forward_name = _forward_name(self._func_name) self._forward_fdef = _EagerDefinedFunction( forward_name, self._graph, self._ops, self._input_placeholders, @@ -394,7 +441,7 @@ class GraphModeFunction(object): # means rerunning the function-defining code will always define the same # function, which is useful if we serialize this etc. function_def_ops = tuple(x - for x in sorted(c.known_ops, key=lambda x: x.name) + for x in sorted(c_known_ops, key=lambda x: x.name) if x not in all_ignored_ops) bname = _backward_name(self._func_name) self._backward_function = GraphModeFunction( @@ -592,10 +639,16 @@ def _defun_internal(name, func, args, kwds): with tmp_graph.as_default(): func_inputs = _get_defun_inputs(args) + def convert(x): + if x is None: + return None + return ops.convert_to_tensor_or_indexed_slices(x) + with capture_tensors(captures): this_tape = tape.push_new_tape() try: func_outputs = func(*func_inputs, **kwds) + func_outputs = nest.map_structure(convert, func_outputs) finally: tape.pop_tape(this_tape) variables = this_tape.watched_variables() @@ -813,3 +866,208 @@ def make_defun_op(func, *args, **kwds): if any(isinstance(x, ops.EagerTensor) for x in kwds.values()): raise ValueError("Tensor keyword arguments are not supported.") return _defun_internal(name, func, args, kwds) + + +class AutomaticControlDependencies(object): + """Context manager to automatically add control dependencies. + + Code under this context manager will act as if a sensible set of control + dependencies were present. More specifically: + 1. All stateful ops in the scope will execute + 2. Stateful ops which modify the same resource will execute in program order + + Note: creating variables in an automatic control dependencies context is not + supported (the value of the variables will never change as they will keep + getting reinitialized). + + NOT THREAD SAFE + """ + + def __init__(self): + self._returned_tensors = set() + + def mark_as_return(self, tensor): + self._returned_tensors.add(tensor) + + def __enter__(self): + if context.in_eager_mode(): + return self + # This code assumes no other thread is adding ops to the graph while + # we're adding ops to the graph. + # TODO(apassos): Fix this by locking the graph or using a temporary + # graph (but that would mess up devices and collections at least, + # probably other things as well). + self._graph = ops.get_default_graph() + self._n_operations = len(self._graph.get_operations()) + return self + + def _process_switch(self, switch_op, ops_which_must_run, + last_op_using_resource_tensor, merge_for_resource): + """Processes a switch node for a resource input. + + When tensorflow creates a cond, it creates a control flow context for each + branch of the cond. Each external tensor accessed by that branch is routed + through a switch op, which gets created in the graph _after_ the op which + uses that tensor get created. + + If the resource comes from another switch op we process that one first. + + _process_switch creates a corresponding merge node for the switch node. This + merge node is added to the outer control flow context of the switch + node. We also ensure that: + + 1. The switch node executes after the previous op which used the resource + tensor + + 2. Any op which uses a resource output of the switch node executes before + the merge for the switch node. + + 3. The next op which uses the input resource to the switch node (which + might be another switch node for the other branch of the conditional) + will execute after the merge node is done. + + 4. The merge node is marked as must_run so it will run even if no + subsequent operation uses the resource. + + Args: + switch_op: the switch op to be processed + ops_which_must_run: the set of ops which must run + last_op_using_resource_tensor: map from resource tensor to last op using + it + merge_for_resource: map from resource tensor to merge which must follow + all usages of it. + """ + inp = switch_op.inputs[0] + if inp.dtype == dtypes_module.resource and inp.op.type == "Switch": + self._process_switch(inp.op, ops_which_must_run, + last_op_using_resource_tensor, merge_for_resource) + if switch_op.outputs[0] in merge_for_resource: + return + new_merge = control_flow_ops.merge(switch_op.outputs, + name="artificial_merge") + new_merge[0].op._control_flow_context = ( # pylint: disable=protected-access + switch_op._control_flow_context.outer_context) # pylint: disable=protected-access + # Ensures the merge always runs + ops_which_must_run.add(new_merge[0].op) + if inp in last_op_using_resource_tensor: + # Ensures the switch exectutes after the previous op using the resource. + switch_op._add_control_input(last_op_using_resource_tensor[inp]) # pylint: disable=protected-access + # Ensure the next op outside the cond happens after the merge. + last_op_using_resource_tensor[inp] = new_merge[0].op + if inp in merge_for_resource: + merge_for_resource[inp]._add_control_input(new_merge[0].op) # pylint: disable=protected-access + for o in switch_op.outputs: + # Ensures the merge will execute after all ops inside the cond + merge_for_resource[o] = new_merge[0].op + + def __exit__(self, unused_type, unused_value, unused_traceback): + if context.in_eager_mode(): + return + + if self._graph is not ops.get_default_graph(): + raise RuntimeError( + "Graph changed while trying to add control dependencies.") + + # map from resource tensor to the last op which used it + last_op_using_resource_tensor = {} + # set of conditional and loop exits + ops_which_must_run = set() + # merge which must depend on ops which use this resource + merge_for_resource = {} + + new_operations = self._graph.get_operations()[self._n_operations:] + + # Ensures that uses of resource tensors get serialized properly and all + # execute. This is done by keeping a map from resource tensor to the last op + # in graph-construction order which used it (last_op_using_resource_tensor). + # + # Conditionals are written in TensorFlow such that every external tensor + # accessed in the conditional goes through a switch op and every return + # tensor (it's guaranteed that there will be at least one) goes through a + # merge op. + # + # To handle conditionals, switches are handled in a special way (see + # comments for _process_switch). Merge nodes created by TF's conditional + # logic (as opposed to by _process_switch) are forced to run and also get a + # control dependency added to them to ensure all stateful ops inside their + # control flow context run. + # + # We also ensure that if an op is using a resource output by a switch node + # (that is, a resource tensor for which there's a value in + # merge_for_resource) this op will run before the merge for that resource. + # + # We try to add control inputs to nodes respecting their control flow + # contexts to avoid dead nodes propagating everywhere and leading to + # "retval[0] doesn't have value" errors. If a node gets a control dependency + # on a dead node (i.e. a note from an untaken control flow branch) that node + # will be marked as dead unless it's a merge node. + # + # TODO(apassos): serialize non-resource-taking stateful ops as well, and + # test that it works. Support while loops. Support init_scope escaping from + # this. + for op in new_operations: + control_inputs = set() + # Ensure stateful ops run + if self._graph._registered_ops[op.type].is_stateful: # pylint: disable=protected-access + ops_which_must_run.add(op) + # Ignore switches (they're handled separately) + if op.type == "Switch" and op.inputs[0].dtype == dtypes_module.resource: + continue + # Make merges trigger all other computation which must run + if op.type == "Merge": + for o in ops_which_must_run: + op._add_control_input(o) # pylint: disable=protected-access + for inp in o.inputs: + if inp in last_op_using_resource_tensor: + last_op_using_resource_tensor[inp] = op + ops_which_must_run = set([op]) + continue + for inp in op.inputs: + if inp.dtype == dtypes_module.resource: + # Deal with switches, finally. + if inp.op.type == "Switch": + self._process_switch(inp.op, ops_which_must_run, + last_op_using_resource_tensor, + merge_for_resource) + # Ensure uses of resources are serialized + if inp in last_op_using_resource_tensor: + if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access + is op._control_flow_context): # pylint: disable=protected-access + control_inputs.add(last_op_using_resource_tensor[inp]) + # Ensure merges happen after the closing of a cond block + if inp in merge_for_resource: + merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access + last_op_using_resource_tensor[inp] = op + control_inputs = [c for c in control_inputs + if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access + op._add_control_inputs(control_inputs) # pylint: disable=protected-access + + # Ensure all ops which must run do run + for r in self._returned_tensors: + r.op._add_control_inputs( # pylint: disable=protected-access + [o for o in ops_which_must_run + if o._control_flow_context is r.op._control_flow_context]) # pylint: disable=protected-access + + +def automatic_control_dependencies(f): + """Wraps f to automatically insert control dependencies. + + The inserted dependencies ensure that: + 1. All stateful ops in f run when the result of f runs + 2. Updates to the same resources happen in order. + + Args: + f: the function to be wrapped. + + Returns: + The wrapped function. + """ + + def wrapper(*args, **kwds): + with AutomaticControlDependencies() as a: + result = f(*args, **kwds) + for t in nest.flatten(result): + a.mark_as_return(t) + return result + + return tf_decorator.make_decorator(f, wrapper) diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py index 3e8e67ac7e242887e1c4f7d89a2e2bc395db22fe..431d9388c0ee97eda197142ec97b9448d985b04b 100644 --- a/tensorflow/python/eager/function_test.py +++ b/tensorflow/python/eager/function_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import variable_scope @@ -595,5 +596,172 @@ class FunctionTest(test.TestCase): create_variable() +class AutomaticControlDependenciesTest(test.TestCase): + + def testBasic(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + with function.AutomaticControlDependencies() as c: + v.assign(v + 1) + v.assign(2 * v) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(), 4.0) + + def testCondMustRun(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + v.assign(v + 1) + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 6.0) + + def testCondMustRunSeparateRead(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + v.assign(v + 1) + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + one = constant_op.constant(1.0) + c.mark_as_return(one) + one.eval(feed_dict={p: False}) + self.assertAllEqual(v.read_value().eval(), 5.0) + one.eval(feed_dict={p: True}) + self.assertAllEqual(v.read_value().eval(), 6.0) + + def testCondNested(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + q = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + v.assign(v + 1, name='true') + return 1.0 + + def false_fn(): + + def inner_true_fn(): + v.assign(v * 2, name='false_true') + return 2.0 + + def inner_false_fn(): + v.assign(v * 3, name='false_false') + return 3.0 + + control_flow_ops.cond(q, inner_true_fn, inner_false_fn) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + with ops.name_scope('final'): + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False, q: False}), 3.0) + self.assertAllEqual(val.eval(feed_dict={p: False, q: True}), 6.0) + self.assertAllEqual(val.eval(feed_dict={p: True, q: True}), 7.0) + self.assertAllEqual(val.eval(feed_dict={p: True, q: False}), 8.0) + + def testCondOneBranch(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 5.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 5.0) + + def testCondOneBranchUpdateBefore(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + v.assign(v * 2) + + def true_fn(): + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 6.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 12.0) + + def testCondOneBranchUpdateAfter(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + p = array_ops.placeholder(dtype=dtypes.bool) + with function.AutomaticControlDependencies() as c: + + def true_fn(): + return 0.0 + + def false_fn(): + v.assign(v + 4) + return 1.0 + + control_flow_ops.cond(p, true_fn, false_fn) + v.assign(v * 2) + val = v.read_value() + c.mark_as_return(val) + self.assertAllEqual(val.eval(feed_dict={p: False}), 10.0) + self.assertAllEqual(val.eval(feed_dict={p: True}), 20.0) + + def testDecorator(self): + with context.graph_mode(), self.test_session(): + v = resource_variable_ops.ResourceVariable(1.0) + variables.global_variables_initializer().run() + + @function.automatic_control_dependencies + def f(): + v.assign(v + 1) + v.assign(2 * v) + return v.read_value() + + self.assertAllEqual(f().eval(), 4.0) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/eager/pywrap_tensor.cc b/tensorflow/python/eager/pywrap_tensor.cc index 6fa076507d11ab9c88891cbeb0a4fb3959e4e99d..3ec2109d323b4f0b2a7e2de0ee13c3317f536a68 100644 --- a/tensorflow/python/eager/pywrap_tensor.cc +++ b/tensorflow/python/eager/pywrap_tensor.cc @@ -185,6 +185,12 @@ typedef struct EagerTensor { // This stores `_keras_mask` object and is set by Tensorflow layers. PyObject* keras_mask; + + // We store a status object here as an optimization to avoid allocating a new + // Status objects on different functions that operate on EagerTensor and need + // to use a TF_Status object. However note that accesses to `status` are not + // thread-safe. + TF_Status* status; } EagerTensor; // tp_init for EagerTensor. @@ -195,6 +201,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { self->handle_data = Py_None; Py_INCREF(Py_None); self->keras_mask = Py_None; + self->status = TF_NewStatus(); PyObject* value; PyObject* context = nullptr; PyObject* device = nullptr; @@ -269,17 +276,17 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { } TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get()); if (desired_dtype >= 0 && desired_dtype != handle_dtype) { - auto out_status = tensorflow::make_safe(TF_NewStatus()); handle = tensorflow::make_safe( EagerCast(GetContext(context), handle.get(), handle_dtype, - static_cast(desired_dtype), out_status.get())); - if (TF_GetCode(out_status.get()) != TF_OK) { - PyErr_SetString( - PyExc_ValueError, - tensorflow::strings::StrCat("Error while casting from DataType ", - handle_dtype, " to ", desired_dtype, ". ", - TF_Message(out_status.get())) - .c_str()); + static_cast(desired_dtype), self->status)); + if (TF_GetCode(self->status) != TF_OK) { + PyErr_SetString(PyExc_ValueError, + tensorflow::strings::StrCat( + "Error while casting from DataType ", handle_dtype, + " to ", desired_dtype, ". ", TF_Message(self->status)) + .c_str()); + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); return -1; } handle_dtype = TFE_TensorHandleDataType(handle.get()); @@ -323,6 +330,7 @@ int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { // tp_dealloc for EagerTensor. void EagerTensor_dealloc(EagerTensor* self) { + TF_DeleteStatus(self->status); Py_DECREF(self->handle_data); Py_DECREF(self->keras_mask); TFE_DeleteTensorHandle(self->handle); @@ -348,12 +356,21 @@ static PyObject* EagerTensor_datatype_enum(EagerTensor* self) { // Getter for `_shape_tuple`. static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { auto handle = self->handle; - int n = TFE_TensorHandleNumDims(handle); + int n = TFE_TensorHandleNumDims(handle, self->status); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); + return nullptr; + } PyObject* shape = PyTuple_New(n); if (PyErr_Occurred()) return nullptr; for (int i = 0; i < n; ++i) { - PyObject* dim = PyLong_FromLongLong(TFE_TensorHandleDim(handle, i)); - if (dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) { + PyObject* dim = + PyLong_FromLongLong(TFE_TensorHandleDim(handle, i, self->status)); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError) || + dim == nullptr || PyTuple_SetItem(shape, i, dim) != 0) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); Py_DECREF(shape); if (dim != nullptr) Py_DECREF(dim); PyErr_SetString(PyExc_RuntimeError, "Error while creating shape"); @@ -365,10 +382,16 @@ static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { // Getter for `_rank`. static PyObject* EagerTensor_rank(EagerTensor* self) { + int num_dims = TFE_TensorHandleNumDims(self->handle, self->status); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); + return nullptr; + } #if PY_MAJOR_VERSION < 3 - return PyInt_FromLong(TFE_TensorHandleNumDims(self->handle)); + return PyInt_FromLong(num_dims); #else - return PyLong_FromLong(TFE_TensorHandleNumDims(self->handle)); + return PyLong_FromLong(num_dims); #endif } @@ -437,10 +460,16 @@ static PyObject* EagerTensor_numpy(EagerTensor* self) { // Getter `device`. static PyObject* EagerTensor_device(EagerTensor* self) { + const char* device = TFE_TensorHandleDeviceName(self->handle, self->status); + if (MaybeRaiseExceptionFromTFStatus(self->status, PyExc_ValueError)) { + // Cleanup self->status before returning. + TF_SetStatus(self->status, TF_OK, ""); + return nullptr; + } #if PY_MAJOR_VERSION >= 3 - return PyUnicode_FromString(TFE_TensorHandleDeviceName(self->handle)); + return PyUnicode_FromString(device); #else - return PyBytes_FromString(TFE_TensorHandleDeviceName(self->handle)); + return PyBytes_FromString(device); #endif } @@ -576,6 +605,7 @@ PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle) { Py_INCREF(Py_None); t->keras_mask = Py_None; t->handle = handle; + t->status = TF_NewStatus(); } return reinterpret_cast(t); } @@ -673,6 +703,7 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) { auto tensor = tensorflow::make_safe(TF_AllocateTensor( TF_INT32, &num_tensors_int, /*num_dims=*/1, /*len=*/4 * num_tensors_int)); int32_t* data = reinterpret_cast(TF_TensorData(tensor.get())); + auto status = tensorflow::make_safe(TF_NewStatus()); for (Py_ssize_t i = 0; i < num_tensors; ++i) { PyObject* tensor_obj = PyList_GET_ITEM(tensor_list, i); if (!EagerTensor_CheckExact(tensor_obj)) { @@ -687,21 +718,27 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensor_list, int slice_dim) { EagerTensor* t = reinterpret_cast(tensor_obj); TFE_TensorHandle* handle = t->handle; - if (slice_dim >= TFE_TensorHandleNumDims(handle)) { - PyErr_SetString(PyExc_IndexError, - tensorflow::strings::StrCat( - "Slice dimension (", slice_dim, - ") must be smaller than rank of all " - "tensors, but tensor at index ", - i, " has rank ", TFE_TensorHandleNumDims(handle)) - .c_str()); + int num_dims = TFE_TensorHandleNumDims(handle, status.get()); + if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) { + return nullptr; + } + if (slice_dim >= num_dims) { + PyErr_SetString( + PyExc_IndexError, + tensorflow::strings::StrCat("Slice dimension (", slice_dim, + ") must be smaller than rank of all " + "tensors, but tensor at index ", + i, " has rank ", num_dims) + .c_str()); + return nullptr; + } + int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get()); + if (MaybeRaiseExceptionFromTFStatus(status.get(), PyExc_ValueError)) { return nullptr; } - int64_t dim = TFE_TensorHandleDim(handle, slice_dim); data[i] = dim; } - auto status = tensorflow::make_safe(TF_NewStatus()); TFE_TensorHandle* handle = TFE_NewTensorHandle(tensor.get(), status.get()); if (TF_GetCode(status.get()) != TF_OK) { PyErr_SetString( diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h index 16b7d1a119a409d1d0a77b220d5d0945b280b638..f9692a8910aa6354c7ed81c7e88aed882058f276 100644 --- a/tensorflow/python/eager/pywrap_tfe.h +++ b/tensorflow/python/eager/pywrap_tfe.h @@ -59,6 +59,15 @@ PyObject* TFE_Py_RegisterExceptionClass(PyObject* e); // This function is not thread-safe. PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e); +// Registers e as the backward_function_getter. +// The registered function creates a backward function (a function that can +// return the gradient of the inputs an op given the gradient of it's outputs). +// The registered function will be passed the following arguments: +// op_name, attrs, num_inputs, op_inputs, op_outputs +// +// This function is not thread-safe. +PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e); + // Returns 0 if 'status' is TF_OK. Otherwise, raises an exception (using // `exception` if not nullptr, else using the class registered via // TFE_Py_RegisterExceptionClass), and returns -1. @@ -165,6 +174,11 @@ PyObject* TFE_Py_TapeGradient(PyObject* tape, PyObject* vspace, // directive. PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args); +// Record the gradient for a given op. +PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, + PyObject* attrs, PyObject* results, + PyObject* name); + // Returns the set of variables watched by the given tape. PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape); diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc index cabbcc48fd56563a50591cc6adabc3af75918401..30e08c8e6531739e3db66a94308e4ce2aff61f11 100644 --- a/tensorflow/python/eager/pywrap_tfe_src.cc +++ b/tensorflow/python/eager/pywrap_tfe_src.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/compactptrset.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/mutex.h" @@ -575,6 +576,9 @@ PyObject* exception_class GUARDED_BY(exception_class_mutex) = nullptr; // Python subclass of Exception that is created to signal fallback. PyObject* fallback_exception_class = nullptr; +// Python function that returns a backward_function. +PyObject* backward_function_getter = nullptr; + tensorflow::mutex _uid_mutex(tensorflow::LINKER_INITIALIZED); tensorflow::int64 _uid GUARDED_BY(_uid_mutex) = 0; @@ -647,6 +651,23 @@ PyObject* TFE_Py_RegisterFallbackExceptionClass(PyObject* e) { } } +PyObject* TFE_Py_RegisterBackwardFunctionGetter(PyObject* e) { + if (backward_function_getter != nullptr) { + Py_DECREF(backward_function_getter); + } + if (!PyCallable_Check(e)) { + backward_function_getter = nullptr; + PyErr_SetString(PyExc_TypeError, + "TFE_Py_RegisterBackwardFunctionGetter: " + "Registered object should be function."); + return nullptr; + } else { + Py_INCREF(e); + backward_function_getter = e; + Py_RETURN_NONE; + } +} + void RaiseFallbackException(const char* message) { if (fallback_exception_class != nullptr) { PyErr_SetObject(fallback_exception_class, Py_BuildValue("s", message)); @@ -1062,16 +1083,10 @@ PyObject* TFE_Py_TapeWatchedVariables(PyObject* tape) { return result; } -void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, - PyObject* input_tensors, - PyObject* backward_function) { - if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) { - return; - } - std::vector input_ids = MakeTensorIDList(input_tensors); - if (PyErr_Occurred()) { - return; - } +namespace { +void TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, + const std::vector& input_ids, + PyObject* backward_function) { std::vector output_info; PyObject* seq = PySequence_Fast(output_tensors, "expected a sequence of integer tensor ids"); @@ -1110,6 +1125,19 @@ void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, [backward_function]() { Py_DECREF(backward_function); }); } } +} // namespace + +void TFE_Py_TapeSetRecordOperation(PyObject* op_type, PyObject* output_tensors, + PyObject* input_tensors, + PyObject* backward_function) { + if (GetTapeSet()->empty() || *ThreadTapeIsStopped()) { + return; + } + std::vector input_ids = MakeTensorIDList(input_tensors); + if (PyErr_Occurred()) return; + + TapeSetRecordOperation(op_type, output_tensors, input_ids, backward_function); +} void TFE_Py_TapeSetDeleteTrace(tensorflow::int64 tensor_id) { for (TFE_Py_Tape* tape : SafeTapeSet()) { @@ -1430,6 +1458,164 @@ bool RaiseIfNotPyList(PyObject* list, const string& attr_name) { return true; } +bool OpDoesntRequireOutput(const string& op_name) { + static tensorflow::gtl::FlatSet* ops_that_dont_require_outputs = + new tensorflow::gtl::FlatSet({ + "Identity", + "MatMul", + "Conv2DBackpropInput", + "Conv2DBackpropFilter", + "Conv3D", + "Conv3DBackpropInputV2", + "AvgPool3D", + "AvgPool3DGrad", + "MaxPool3D", + "MaxPool3DGrad", + "MaxPool3DGradGrad", + "BiasAdd", + "BiasAddV1", + "BiasAddGrad", + "Relu6", + "Softplus", + "SoftplusGrad", + "Softsign", + "ReluGrad", + "Conv2D", + "DepthwiseConv2dNative", + "Dilation2D", + "AvgPool", + "AvgPoolGrad", + "BatchNormWithGlobalNormalization", + "L2Loss", + "Sum", + "Prod", + "SegmentSum", + "SegmentMean", + "SparseSegmentSum", + "SparseSegmentMean", + "SparseSegmentSqrtN", + "SegmentMin", + "SegmentMax", + "UnsortedSegmentSum", + "UnsortedSegmentMax", + "Abs", + "Neg", + "ReciprocalGrad", + "Square", + "Expm1", + "Log", + "Log1p", + "TanhGrad", + "SigmoidGrad", + "Sign", + "Sin", + "Cos", + "Tan", + "Add", + "Sub", + "Mul", + "Div", + "RealDiv", + "Maximum", + "Minimum", + "SquaredDifference", + "Select", + "SparseMatMul", + "BatchMatMul", + "Complex", + "Real", + "Imag", + "Angle", + "Conj", + "Cast", + "Cross", + "Cumsum", + "Cumprod", + "ReadVariableOp", + "VarHandleOp", + "Shape", + }); + + return ops_that_dont_require_outputs->find(op_name) != + ops_that_dont_require_outputs->end(); +} + +bool OpDoesntRequireInput(const string& op_name) { + static tensorflow::gtl::FlatSet* ops_that_dont_require_inputs = + new tensorflow::gtl::FlatSet({ + "Identity", + "Softmax", + "LogSoftmax", + "BiasAdd", + "Relu", + "Elu", + "Selu", + "SparseSoftmaxCrossEntropyWithLogits", + "Neg", + "Inv", + "Reciprocal", + "Sqrt", + "Exp", + "Tanh", + "Sigmoid", + "Real", + "Imag", + "Conj", + "ReadVariableOp", + "VarHandleOp", + "Shape", + }); + + return ops_that_dont_require_inputs->find(op_name) != + ops_that_dont_require_inputs->end(); +} + +PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs, + PyObject* results, PyObject* name) { + std::vector input_ids = MakeTensorIDList(inputs); + if (PyErr_Occurred()) return nullptr; + + bool should_record = false; + for (TFE_Py_Tape* tape : SafeTapeSet()) { + if (tape->tape->ShouldRecord(input_ids)) { + should_record = true; + break; + } + } + + if (!should_record) Py_RETURN_NONE; + + string c_op_name = TFE_GetPythonString(op_name); + PyObject* op_outputs; + if (OpDoesntRequireOutput(c_op_name)) { + op_outputs = Py_None; + } else { + op_outputs = results; + } + + PyObject* op_inputs; + if (OpDoesntRequireInput(c_op_name)) { + op_inputs = Py_None; + } else { + op_inputs = inputs; + } + + PyObject* num_inputs = PyLong_FromLong(PySequence_Size(inputs)); + PyObject* callback_args = + Py_BuildValue("OOOOO", op_name, attrs, num_inputs, op_inputs, op_outputs); + + PyObject* backward_function = + PyObject_CallObject(backward_function_getter, callback_args); + Py_DECREF(callback_args); + if (backward_function == nullptr) return nullptr; + + TapeSetRecordOperation(op_name, results, input_ids, backward_function); + + Py_DECREF(backward_function); + + Py_RETURN_NONE; +} + bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks, const tensorflow::OpDef* op_def, PyObject* args, const std::vector& flattened_inputs, @@ -1471,21 +1657,7 @@ bool RunCallbacks(bool run_gradient_callback, bool run_post_exec_callbacks, }); if (run_gradient_callback) { - if (!PyCallable_Check(record_gradient_callback)) { - PyErr_SetString(PyExc_TypeError, - Printf("expected a function for " - "record_gradient_callback, got %s instead", - record_gradient_callback->ob_type->tp_name) - .c_str()); - return false; - } - - PyObject* callback_result = - PyObject_CallObject(record_gradient_callback, callback_args); - if (!callback_result) { - return false; - } - Py_DECREF(callback_result); + RecordGradient(op_name, inputs, attrs, flattened_result, name); } if (run_post_exec_callbacks) { @@ -1796,3 +1968,13 @@ PyObject* TFE_Py_FastPathExecute_C(PyObject*, PyObject* args) { Py_DECREF(flat_result); return result; } + +PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs, + PyObject* attrs, PyObject* results, + PyObject* name) { + if (*ThreadTapeIsStopped() || GetTapeSet()->empty()) { + Py_RETURN_NONE; + } + + return RecordGradient(op_name, inputs, attrs, results, name); +} diff --git a/tensorflow/python/estimator/canned/baseline_test.py b/tensorflow/python/estimator/canned/baseline_test.py index 96639e88ea4a07e14121049d78f07e03fcb22156..18c955f5a0e998de983b31fc4cc595895e6bbcbd 100644 --- a/tensorflow/python/estimator/canned/baseline_test.py +++ b/tensorflow/python/estimator/canned/baseline_test.py @@ -1075,7 +1075,7 @@ class BaselineClassifierEvaluationTest(test.TestCase): metric_keys.MetricKeys.LABEL_MEAN: 1., metric_keys.MetricKeys.ACCURACY_BASELINE: 1, metric_keys.MetricKeys.AUC: 0., - metric_keys.MetricKeys.AUC_PR: 1., + metric_keys.MetricKeys.AUC_PR: 0.5, } else: # Multi classes: loss = 1 * -log ( softmax(logits)[label] ) @@ -1136,7 +1136,7 @@ class BaselineClassifierEvaluationTest(test.TestCase): metric_keys.MetricKeys.LABEL_MEAN: 0.5, metric_keys.MetricKeys.ACCURACY_BASELINE: 0.5, metric_keys.MetricKeys.AUC: 0.5, - metric_keys.MetricKeys.AUC_PR: 0.75, + metric_keys.MetricKeys.AUC_PR: 0.25, } else: # Expand logits since batch_size=2 @@ -1212,7 +1212,7 @@ class BaselineClassifierEvaluationTest(test.TestCase): metric_keys.MetricKeys.ACCURACY_BASELINE: ( max(label_mean, 1-label_mean)), metric_keys.MetricKeys.AUC: 0.5, - metric_keys.MetricKeys.AUC_PR: 2. / (1. + 2.), + metric_keys.MetricKeys.AUC_PR: 0.16666645, } else: # Multi classes: unweighted_loss = 1 * -log ( soft_max(logits)[label] ) diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py index c29b5cabc7f04994365948dffec78382a871ff40..7043da8de036e5be27d223271c37e065d9ffbcdd 100644 --- a/tensorflow/python/estimator/canned/dnn.py +++ b/tensorflow/python/estimator/canned/dnn.py @@ -150,9 +150,7 @@ def _dnn_model_fn(features, config: `RunConfig` object to configure the runtime settings. Returns: - predictions: A dict of `Tensor` objects. - loss: A scalar containing the loss of the step. - train_op: The op for training. + An `EstimatorSpec` instance. Raises: ValueError: If features has the wrong type. diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py index 0c54013a5240e6d19c4774958a3a54e9563cbc47..6d0fb96057ee93964ee3571bae3b878faad88882 100644 --- a/tensorflow/python/estimator/canned/dnn_linear_combined.py +++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py @@ -117,7 +117,7 @@ def _dnn_linear_combined_model_fn(features, config: `RunConfig` object to configure the runtime settings. Returns: - `ModelFnOps` + An `EstimatorSpec` instance. Raises: ValueError: If both `linear_feature_columns` and `dnn_features_columns` diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py index 706575985ff9e0fef94f110825ec11af33031ea3..cbae43e4f7fef0271de20a4ec54449989455d4bd 100644 --- a/tensorflow/python/estimator/canned/dnn_testing_utils.py +++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py @@ -1041,7 +1041,7 @@ class BaseDNNClassifierEvaluateTest(object): # There is no good way to calculate AUC for only two data points. But # that is what the algorithm returns. metric_keys.MetricKeys.AUC: 0.5, - metric_keys.MetricKeys.AUC_PR: 0.75, + metric_keys.MetricKeys.AUC_PR: 0.25, ops.GraphKeys.GLOBAL_STEP: global_step }, dnn_classifier.evaluate(input_fn=_input_fn, steps=1)) diff --git a/tensorflow/python/estimator/canned/head.py b/tensorflow/python/estimator/canned/head.py index cb9e3fc6ca116ac0f48a37cea92fa4119754f324..8d742a2c6147e86619d4c0aad59b69459384bd4d 100644 --- a/tensorflow/python/estimator/canned/head.py +++ b/tensorflow/python/estimator/canned/head.py @@ -1156,6 +1156,7 @@ def _regression_head_with_mean_squared_error_loss( label_dimension=1, loss_reduction=losses.Reduction.SUM, loss_fn=None, + inverse_link_fn=None, name=None): """Creates a `_Head` for regression using the `mean_squared_error` loss. @@ -1174,10 +1175,16 @@ def _regression_head_with_mean_squared_error_loss( `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN, label_dimension]`. - Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + Supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or `(labels, logits, features)` as arguments and returns unreduced loss with shape `[D0, D1, ... DN, label_dimension]`. + Also supports custom `inverse_link_fn`, also known as 'mean function'. + `inverse_link_fn` takes `logits` as argument and returns predicted values. + This function is the inverse of the link function defined in + https://en.wikipedia.org/wiki/Generalized_linear_model#Link_function + Namely, for poisson regression, set `inverse_link_fn=tf.exp`. + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -1188,7 +1195,9 @@ def _regression_head_with_mean_squared_error_loss( `[batch_size, label_dimension]`). loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to reduce training loss over batch. Defaults to `SUM`. - loss_fn: Optional loss function. + loss_fn: Optional loss function. Defaults to `mean_squared_error`. + inverse_link_fn: Optional inverse link function, also known as 'mean + function'. Defaults to identity. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -1208,6 +1217,7 @@ def _regression_head_with_mean_squared_error_loss( label_dimension=label_dimension, loss_reduction=loss_reduction, loss_fn=loss_fn, + inverse_link_fn=inverse_link_fn, name=name) @@ -1220,6 +1230,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): weight_column=None, loss_reduction=losses.Reduction.SUM, loss_fn=None, + inverse_link_fn=None, name=None): """`Head` for regression.""" if label_dimension < 1: @@ -1228,6 +1239,7 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): self._weight_column = weight_column self._loss_reduction = loss_reduction self._loss_fn = loss_fn + self._inverse_link_fn = inverse_link_fn self._name = name @property @@ -1294,9 +1306,19 @@ class _RegressionHeadWithMeanSquaredErrorLoss(_Head): # Predict. with ops.name_scope(self._name, 'head'): logits = _check_logits_final_dim(logits, self._logits_dimension) - predictions = {prediction_keys.PredictionKeys.PREDICTIONS: logits} + if self._inverse_link_fn: + predicted_value = self._inverse_link_fn(logits) + predictions = { + prediction_keys.PredictionKeys.PREDICTIONS: predicted_value, + prediction_keys.PredictionKeys.LOGITS: logits, + } + else: + predicted_value = logits + predictions = { + prediction_keys.PredictionKeys.PREDICTIONS: predicted_value} if mode == model_fn.ModeKeys.PREDICT: - regression_output = export_output.RegressionOutput(value=logits) + regression_output = export_output.RegressionOutput( + value=predicted_value) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, diff --git a/tensorflow/python/estimator/canned/head_test.py b/tensorflow/python/estimator/canned/head_test.py index 3a03770af498981a054c3df9155e83a60c7f0350..a300f315c18f60e77f262a3b961c5ef6306bc235 100644 --- a/tensorflow/python/estimator/canned/head_test.py +++ b/tensorflow/python/estimator/canned/head_test.py @@ -1558,7 +1558,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.LABEL_MEAN: 2./2, keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., - keys.AUC_PR: 1., + keys.AUC_PR: 0.74999905, } # Assert spec contains expected tensors. @@ -1636,7 +1636,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.LABEL_MEAN: 2./2, keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., - keys.AUC_PR: 1., + keys.AUC_PR: 0.75, } # Assert predictions, loss, and metrics. @@ -1741,7 +1741,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.LABEL_MEAN: 2./2, keys.ACCURACY_BASELINE: 2./2, keys.AUC: 0., - keys.AUC_PR: 1., + keys.AUC_PR: 0.74999905, keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 1., keys.PRECISION_AT_THRESHOLD % thresholds[0]: 1., keys.RECALL_AT_THRESHOLD % thresholds[0]: 1., @@ -2188,7 +2188,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): keys.LABEL_MEAN: expected_label_mean, keys.ACCURACY_BASELINE: 1 - expected_label_mean, keys.AUC: .45454565, - keys.AUC_PR: .6737757325172424, + keys.AUC_PR: .21923049, } # Assert spec contains expected tensors. @@ -2487,7 +2487,7 @@ class BinaryLogisticHeadWithSigmoidCrossEntropyLossTest(test.TestCase): # We cannot reliably calculate AUC with only 4 data points, but the # values should not change because of backwards-compatibility. keys.AUC: 0.5222, - keys.AUC_PR: 0.7341, + keys.AUC_PR: 0.5119, } tol = 1e-2 @@ -2703,10 +2703,9 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): self.assertIsNone(spec.loss) self.assertEqual({}, spec.eval_metric_ops) self.assertIsNone(spec.train_op) + default_serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY self.assertItemsEqual( - (signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY, - 'predict', - 'regression'), + (default_serving_key, 'predict', 'regression'), spec.export_outputs.keys()) _assert_no_hooks(self, spec) @@ -2714,6 +2713,54 @@ class RegressionHeadWithMeanSquaredErrorLossTest(test.TestCase): with self.test_session(): _initialize_variables(self, spec.scaffold) self.assertAllClose(logits, spec.predictions[prediction_key].eval()) + self.assertAllClose( + logits, spec.export_outputs[default_serving_key].value.eval()) + self.assertAllClose( + logits, spec.export_outputs['regression'].value.eval()) + self.assertAllClose( + logits, spec.export_outputs['predict'].outputs['predictions'].eval()) + + def test_predict_with_inverse_link_fn(self): + def _inverse_link_fn(logits): + return logits - 10. + head = head_lib._regression_head_with_mean_squared_error_loss( + inverse_link_fn=_inverse_link_fn) + + # Create estimator spec. + logits = np.array(((45,), (41,),), dtype=np.int32) + expected_predictions = np.array(((35,), (31,),), dtype=np.int32) + spec = head.create_estimator_spec( + features={'x': np.array(((42.,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + # Assert spec contains expected tensors. + keys = prediction_keys.PredictionKeys + self.assertItemsEqual( + (keys.PREDICTIONS, keys.LOGITS), spec.predictions.keys()) + self.assertEqual(dtypes.float32, spec.predictions[keys.PREDICTIONS].dtype) + self.assertEqual(dtypes.float32, spec.predictions[keys.LOGITS].dtype) + default_serving_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + self.assertItemsEqual( + (default_serving_key, 'predict', 'regression'), + spec.export_outputs.keys()) + + # Assert predictions. + with self.test_session(): + _initialize_variables(self, spec.scaffold) + self.assertAllClose( + expected_predictions, spec.predictions[keys.PREDICTIONS].eval()) + self.assertAllClose(logits, spec.predictions[keys.LOGITS].eval()) + self.assertAllClose( + expected_predictions, + spec.export_outputs[default_serving_key].value.eval()) + self.assertAllClose( + expected_predictions, spec.export_outputs['regression'].value.eval()) + self.assertAllClose( + expected_predictions, + spec.export_outputs['predict'].outputs['predictions'].eval()) + self.assertAllClose( + logits, spec.export_outputs['predict'].outputs['logits'].eval()) def test_eval_create_loss(self): head = head_lib._regression_head_with_mean_squared_error_loss() diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py index 3e9183cf1b633757074377472e9b4cac953e04a1..e88fcbbd2e0e3617dde428662e58b1d86c4eddd0 100644 --- a/tensorflow/python/estimator/canned/linear_testing_utils.py +++ b/tensorflow/python/estimator/canned/linear_testing_utils.py @@ -1342,7 +1342,7 @@ class BaseLinearClassifierEvaluationTest(object): metric_keys.MetricKeys.LABEL_MEAN: 1., metric_keys.MetricKeys.ACCURACY_BASELINE: 1, metric_keys.MetricKeys.AUC: 0., - metric_keys.MetricKeys.AUC_PR: 1., + metric_keys.MetricKeys.AUC_PR: 0.5, } else: # Multi classes: loss = 1 * -log ( soft_max(logits)[label] ) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 7bf838e5a0ccc10cc6cf8dd5b18a44565c920d46..1167b3834eb6a79abf670f629ec2cbc37957d191 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -427,7 +427,8 @@ class Estimator(object): input_fn, predict_keys=None, hooks=None, - checkpoint_path=None): + checkpoint_path=None, + yield_single_examples=True): """Yields predictions for given features. Args: @@ -453,13 +454,18 @@ class Estimator(object): inside the prediction call. checkpoint_path: Path of a specific checkpoint to predict. If `None`, the latest checkpoint in `model_dir` is used. + yield_single_examples: If False, yield the whole batch as returned by the + model_fn instead of decomposing the batch into individual elements. This + is useful if model_fn return some tensor with first dimension not + equal to the batch size Yields: Evaluated values of `predictions` tensors. Raises: ValueError: Could not find a trained model in model_dir. - ValueError: if batch length of predictions are not same. + ValueError: if batch length of predictions are not same and + yield_single_examples is True. ValueError: If there is a conflict between `predict_keys` and `predictions`. For example if `predict_keys` is not `None` but `EstimatorSpec.predictions` is not a `dict`. @@ -492,7 +498,9 @@ class Estimator(object): hooks=all_hooks) as mon_sess: while not mon_sess.should_stop(): preds_evaluated = mon_sess.run(predictions) - if not isinstance(predictions, dict): + if not yield_single_examples: + yield preds_evaluated + elif not isinstance(predictions, dict): for pred in preds_evaluated: yield pred else: @@ -1106,7 +1114,7 @@ def _write_dict_to_summary(output_dir, isinstance(dictionary[key], np.int32) or isinstance(dictionary[key], int)): summary_proto.value.add(tag=key, simple_value=int(dictionary[key])) - elif isinstance(dictionary[key], six.string_types): + elif isinstance(dictionary[key], six.binary_type): try: summ = summary_pb2.Summary.FromString(dictionary[key]) for i, _ in enumerate(summ.value): diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 39a5b998ebdcccfbeddf0fc96dab44dc91a289fa..7a0745b1d0d5ae932fa59be56a4952e82922a584 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -80,18 +80,18 @@ def dummy_model_fn(features, labels, params): _, _, _ = features, labels, params -def check_eventfile_for_keyword(keyword, est): +def check_eventfile_for_keyword(keyword, dir_): """Checks event files for the keyword.""" writer_cache.FileWriterCache.clear() # Get last Event written. - event_paths = glob.glob(os.path.join(est.model_dir, 'events*')) + event_paths = glob.glob(os.path.join(dir_, 'events*')) last_event = None for last_event in summary_iterator.summary_iterator(event_paths[-1]): if last_event.summary is not None: - if last_event.summary.value: - if keyword in last_event.summary.value[0].tag: + for value in last_event.summary.value: + if keyword in value.tag: return True return False @@ -610,7 +610,7 @@ class EstimatorTrainTest(test.TestCase): # Make sure nothing is stuck in limbo. writer_cache.FileWriterCache.clear() - if check_eventfile_for_keyword('loss', est): + if check_eventfile_for_keyword('loss', est.model_dir): return self.fail('{} should be part of reported summaries.'.format('loss')) @@ -1290,8 +1290,9 @@ class EstimatorEvaluateTest(test.TestCase): # Make sure nothing is stuck in limbo. writer_cache.FileWriterCache.clear() - # Get last Event written. - if check_eventfile_for_keyword('image', est): + # Get last evaluation Event written. + if check_eventfile_for_keyword('image', os.path.join(est.model_dir, + 'eval')): return self.fail('{} should be part of reported summaries.'.format('image')) @@ -1472,6 +1473,27 @@ class EstimatorPredictTest(test.TestCase): 'Batch length of predictions should be same'): next(est.predict(dummy_input_fn)) + def test_iterate_batches(self): + + def _model_fn(features, labels, mode): + _, _ = features, labels + return model_fn_lib.EstimatorSpec( + mode, + loss=constant_op.constant(0.), + train_op=state_ops.assign_add(training.get_global_step(), 1), + predictions={ + # First dim is different but the prediction should still work + 'y1': array_ops.zeros(shape=[3]), + 'y2': array_ops.zeros(shape=[5, 3]) + }) + + est = estimator.Estimator(model_fn=_model_fn) + est.train(dummy_input_fn, steps=1) + + predictions = next(est.predict(dummy_input_fn, yield_single_examples=False)) + self.assertAllEqual(predictions['y1'].shape, [3]) + self.assertAllEqual(predictions['y2'].shape, [5, 3]) + def test_predict_keys_defined_for_tensor(self): def _model_fn(features, labels, mode): diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 63328dcfb55646ce2aaf8929d5517c8522c418f2..2cc3331a15867e9a984847391857bf84baee7424 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -455,15 +455,21 @@ class _TrainingExecutor(object): train_hooks=None, continuous_eval_listener=None): if not isinstance(estimator, estimator_lib.Estimator): - raise TypeError('`estimator` must have type `tf.estimator.Estimator`.') + raise TypeError( + '`estimator` must have type `tf.estimator.Estimator`. ' + 'Got: {}'.format(type(estimator))) self._estimator = estimator if not isinstance(train_spec, TrainSpec): - raise TypeError('`train_spec` must have type `tf.estimator.TrainSpec`.') + raise TypeError( + '`train_spec` must have type `tf.estimator.TrainSpec`. ' + 'Got: {}'.format(type(train_spec))) self._train_spec = train_spec if not isinstance(eval_spec, EvalSpec): - raise TypeError('`eval_spec` must have type `tf.estimator.EvalSpec`.') + raise TypeError( + '`eval_spec` must have type `tf.estimator.EvalSpec`. ' + 'Got: {}'.format(type(eval_spec))) self._eval_spec = eval_spec self._train_hooks = _validate_hooks(train_hooks) diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py index b7ba76d8714e6b13551bb3e18083f45e53d2afc3..3ce8eea84b6bf601ce89dfaa7d8e3a5d193468b3 100644 --- a/tensorflow/python/estimator/util.py +++ b/tensorflow/python/estimator/util.py @@ -21,10 +21,12 @@ from __future__ import print_function import functools +from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect def _is_bounded_method(fn): + _, fn = tf_decorator.unwrap(fn) return tf_inspect.ismethod(fn) and (fn.__self__ is not None) diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py index 52e42ef0188115750a3712f4fe07976a456e61e2..c416881c3119c160d28f4b8e37cd2aeb22f239a6 100644 --- a/tensorflow/python/feature_column/feature_column.py +++ b/tensorflow/python/feature_column/feature_column.py @@ -512,6 +512,7 @@ def make_parse_example_spec(feature_columns): ```python # Define features and transformations + feature_a = categorical_column_with_vocabulary_file(...) feature_b = numeric_column(...) feature_c_bucketized = bucketized_column(numeric_column("feature_c"), ...) feature_a_x_feature_c = crossed_column( diff --git a/tensorflow/python/framework/common_shapes.py b/tensorflow/python/framework/common_shapes.py index 3b1092f923112dbd9a081942d40162ae384bf167..3c5aebbce8af117aa1e216f1ef07ded181c997ea 100644 --- a/tensorflow/python/framework/common_shapes.py +++ b/tensorflow/python/framework/common_shapes.py @@ -34,7 +34,7 @@ def scalar_shape(unused_op): def unchanged_shape(op): - """Shape function for ops that output an tensor like their first input.""" + """Shape function for ops that output a tensor like their first input.""" return [op.inputs[0].get_shape()] diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index cba225e749d88a45c43266e45172a7335a8e0b71..caa604999c2fad4ce111d910a77e4b99399c11ca 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -353,8 +353,10 @@ class _DefinedFunction(object): outputs = (outputs,) if any([_ is None for _ in outputs]): raise ValueError("Function can not return None.") - # Ensures each output is a Tensor. - outputs = [ops.convert_to_tensor(_) for _ in outputs] + # Ensures each output is a Tensor in the function graph. + outputs = [ops.convert_to_tensor(t) for t in outputs] + outputs = [temp_graph.capture(t) if t.graph is not temp_graph else t + for t in outputs] self._extra_inputs = temp_graph.extra_inputs inputs.extend(temp_graph.extra_args) # pylint: disable=protected-access @@ -683,28 +685,34 @@ class _FuncGraph(ops.Graph): def create_op(self, op_type, inputs, data_types, **kwargs): for i, x in enumerate(inputs): if isinstance(x, ops.EagerTensor) or x.graph is not self: - # Referring to a tensor from other graph. - if x in self._captured: - # Captured already. - inputs[i] = self._captured[x] - elif self._capture_by_value: - inputs[i] = self._add_tensor_and_parents(x) - else: - # Substitute with a placeholder. - self.extra_inputs.append(x) - # Hoist the new input placeholder out of any control flow context - # we're currently in. - with ops.control_dependencies(None): - ph = array_ops.placeholder(x.dtype, shape=x.get_shape()) - # pylint: disable=protected-access - ph._handle_data = x._handle_data - # pylint: enable=protected-access - inputs[i] = ph - self._captured[x] = ph - self.extra_args.append(ph) + inputs[i] = self.capture(x) return super(_FuncGraph, self).create_op(op_type, inputs, data_types, **kwargs) + def capture(self, tensor): + """Adds the given tensor to this graph and returns the captured tensor.""" + if tensor in self._captured: + # Captured already. + return self._captured[tensor] + elif self._capture_by_value: + return self._add_tensor_and_parents(tensor) + else: + return self._capture_tensor_as_extra_input(tensor) + + def _capture_tensor_as_extra_input(self, tensor): + # Substitute with a placeholder. + self.extra_inputs.append(tensor) + # Hoist the new input placeholder out of any control flow context + # we're currently in. + with ops.control_dependencies(None): + ph = array_ops.placeholder(tensor.dtype, shape=tensor.get_shape()) + # pylint: disable=protected-access + ph._handle_data = tensor._handle_data + # pylint: enable=protected-access + self._captured[tensor] = ph + self.extra_args.append(ph) + return ph + def _add_tensor_and_parents(self, tensor): op = self._add_op_and_parents(tensor.op) return op.outputs[tensor.value_index] diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 301a7f682dde8dbeccd1e81675b0059433990a09..52052ba77d42fa91692e7699f49898d0c01c22be 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -725,9 +725,16 @@ class FunctionTest(test.TestCase): y = Foo(constant_op.constant([[10.]])) + @function.Defun() + def Bar(): + return w + + z = Bar() + with self.test_session(graph=g): variables.global_variables_initializer().run() self.assertAllEqual(y.eval(), [[12.0]]) + self.assertAllEqual(z.eval(), [[1.0]]) def testCaptureControls(self): g = ops.Graph() diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py index 8c03a5f19dee31a6609590e46d608af9a686c5fe..4c1bd736d727e974375ad9008a579361137fb9d6 100644 --- a/tensorflow/python/framework/meta_graph.py +++ b/tensorflow/python/framework/meta_graph.py @@ -741,6 +741,7 @@ def import_scoped_meta_graph(meta_graph_or_file, producer_op_list=producer_op_list) # Restores all the other collections. + variable_objects = {} for key, col_def in sorted(meta_graph_def.collection_def.items()): # Don't add unbound_inputs to the new graph. if key == unbound_inputs_col_name: @@ -756,11 +757,23 @@ def import_scoped_meta_graph(meta_graph_or_file, from_proto = ops.get_from_proto_function(key) if from_proto and kind == "bytes_list": proto_type = ops.get_collection_proto_type(key) - for value in col_def.bytes_list.value: - proto = proto_type() - proto.ParseFromString(value) - graph.add_to_collection( - key, from_proto(proto, import_scope=scope_to_prepend_to_names)) + if key in ops.GraphKeys._VARIABLE_COLLECTIONS: # pylint: disable=protected-access + for value in col_def.bytes_list.value: + variable = variable_objects.get(value, None) + if variable is None: + proto = proto_type() + proto.ParseFromString(value) + variable = from_proto( + proto, import_scope=scope_to_prepend_to_names) + variable_objects[value] = variable + graph.add_to_collection(key, variable) + else: + for value in col_def.bytes_list.value: + proto = proto_type() + proto.ParseFromString(value) + graph.add_to_collection( + key, from_proto( + proto, import_scope=scope_to_prepend_to_names)) else: field = getattr(col_def, kind) if key in _COMPAT_COLLECTION_LIST: diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py index f2f1e83da15eacdbb4f194967b51559d279ae1a4..19dcd6a1b34741290b2578d93b79883c103fdb1b 100644 --- a/tensorflow/python/framework/meta_graph_test.py +++ b/tensorflow/python/framework/meta_graph_test.py @@ -261,6 +261,29 @@ class SimpleMetaGraphTest(test.TestCase): self.assertEqual(node_def.attr["attr_1"].i, 1) self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs) + def testVariableObjectsAreSharedAmongCollections(self): + with ops.Graph().as_default() as graph1: + v = variables.Variable(3.0) + # A single instance of Variable is shared among the collections: + global_vars = graph1.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + trainable_vars = graph1.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertEqual(len(global_vars), 1) + self.assertEqual(len(trainable_vars), 1) + self.assertIs(global_vars[0], trainable_vars[0]) + self.assertIs(v, global_vars[0]) + + orig_meta_graph, _ = meta_graph.export_scoped_meta_graph(graph=graph1) + del graph1 # To avoid accidental references in code involving graph2. + + with ops.Graph().as_default() as graph2: + meta_graph.import_scoped_meta_graph(orig_meta_graph) + global_vars = graph2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + trainable_vars = graph2.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + self.assertEqual(len(global_vars), 1) + self.assertEqual(len(trainable_vars), 1) + # A single instance of Variable is shared among the collections: + self.assertIs(global_vars[0], trainable_vars[0]) + @test_util.with_c_api class ScopedMetaGraphTest(test.TestCase): @@ -883,21 +906,25 @@ class ExportImportAcrossScopesTest(test.TestCase): graph_fn(use_resource=use_resource) if use_resource: - # Bringing in a collection that contains ResourceVariables adds ops - # to the graph, so mimic the same behavior. + # Bringing in collections that contain ResourceVariables will adds ops + # to the graph the first time a variable is encountered, so mimic the + # same behavior. + seen_variables = set() for collection_key in sorted([ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES, ]): for var in expected_graph.get_collection(collection_key): - var._read_variable_op() + if var not in seen_variables: + var._read_variable_op() + seen_variables.add(var) result = meta_graph.export_scoped_meta_graph(graph=imported_graph)[0] expected = meta_graph.export_scoped_meta_graph(graph=expected_graph)[0] if use_resource: # Clear all shared_name attributes before comparing, since they are - # supposed to be orthogonal to scopes. + # orthogonal to scopes and are not updated on export/import. for meta_graph_def in [result, expected]: for node in meta_graph_def.graph_def.node: shared_name_attr = "shared_name" diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 77e83554c99b6abcceca908856dc6b1cdbce98b3..b0d2704c0747c66acb1af987ea3d2943d98169f0 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -368,8 +368,8 @@ class Tensor(_TensorLike): A `TensorShape` representing the shape of this tensor. """ - if _USE_C_API: - graph = self._op._graph._c_graph # pylint: disable=protected-access + graph = self._op._graph._c_graph # pylint: disable=protected-access + if graph: with errors.raise_exception_on_not_ok_status() as status: num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(), status) @@ -466,7 +466,7 @@ class Tensor(_TensorLike): ValueError: If `shape` is not compatible with the current shape of this tensor. """ - if not _USE_C_API: + if not self._op._graph._c_graph: # pylint: disable=protected-access # ASIM self._shape_val = self._shape_val.merge_with(shape) return if not isinstance(shape, tensor_shape.TensorShape): @@ -2707,15 +2707,21 @@ class Graph(object): self._name_stack = "" # Maps a name used in the graph to the next id to use for that name. self._names_in_use = {} + self._stack_state_is_thread_local = False + self._thread_local = threading.local() # Functions that will be applied to choose a device if none is specified. - self._device_function_stack = [] + # After switch_to_thread_local(), self._thread_local._device_function_stack + # is used instead. + self._graph_device_function_stack = [] # Default original_op applied to new ops. self._default_original_op = None # Current control flow context. It could be either CondContext or # WhileContext defined in ops/control_flow_ops.py self._control_flow_context = None # A new node will depend of the union of all of the nodes in the stack. - self._control_dependencies_stack = [] + # After switch_to_thread_local(), + # self._thread_local._control_dependencies_stack is used instead. + self._graph_control_dependencies_stack = [] # Arbitrary collections of objects. self._collections = {} # The graph-level random seed @@ -2737,8 +2743,9 @@ class Graph(object): producer=versions.GRAPH_DEF_VERSION, min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER) self._building_function = False - # Stack of colocate_with ops - self._colocation_stack = [] + # Stack of colocate_with ops. After switch_to_thread_local(), + # self._thread_local._colocation_stack is used instead. + self._graph_colocation_stack = [] # Set of tensors that are dangerous to feed! self._unfeedable_tensors = set() # Set of operations that are dangerous to fetch! @@ -2761,8 +2768,12 @@ class Graph(object): # TODO(skyewm): fold as much of the above as possible into the C # implementation - if _USE_C_API or self._use_c_api_hack(): + if self._use_c_api_hack(): self._scoped_c_graph = c_api_util.ScopedTFGraph() + # The C API requires all ops to have shape functions. Disable this + # requirement (many custom ops do not have shape functions, and we don't + # want to break these existing cases). + c_api.SetRequireShapeInferenceFns(self._c_graph, False) else: self._scoped_c_graph = None self._variable_creator_stack = [] @@ -2770,7 +2781,7 @@ class Graph(object): # TODO(apassos) remove once the C API is used by default. def _use_c_api_hack(self): """Temporary hack; can be overridden to force C API usage.""" - return False + return _USE_C_API def _convert_stack(self, stack, include_func_start_lineno=False): """Converts a stack extracted using _extract_stack() to a traceback stack. @@ -3030,7 +3041,7 @@ class Graph(object): """ # pylint: enable=line-too-long - if _USE_C_API: + if self._c_graph: with self._lock: with c_api_util.tf_buffer() as buf: with errors.raise_exception_on_not_ok_status() as status: @@ -3350,9 +3361,9 @@ class Graph(object): if (op.device and pydev.canonical_name(op.device) != pydev.canonical_name(colocation_op.device)): logging.warning("Tried to colocate %s with an op %s that had " - "a different device: %s vs %s. " - "Ignoring colocation property.", op.name, - colocation_op.name, op.device, + "a different device: %s vs %s. Postponing " + "error-checking until all devices are assigned.", + op.name, colocation_op.name, op.device, colocation_op.device) else: op._set_device(colocation_op.device) # pylint: disable=protected-access @@ -4669,6 +4680,79 @@ class Graph(object): else: return tensor_or_op not in self._unfetchable_ops + def switch_to_thread_local(self): + """Make device, colocation and dependencies stacks thread-local. + + Device, colocation and dependencies stacks are not thread-local be default. + If multiple threads access them, then the state is shared. This means that + one thread may affect the behavior of another thread. + + After this method is called, the stacks become thread-local. If multiple + threads access them, then the state is not shared. Each thread uses its own + value; a thread doesn't affect other threads by mutating such a stack. + + The initial value for every thread's stack is set to the current value + of the stack when `switch_to_thread_local()` was first called. + """ + if not self._stack_state_is_thread_local: + self._stack_state_is_thread_local = True + + @property + def _device_function_stack(self): + if self._stack_state_is_thread_local: + # This may be called from a thread where device_function_stack doesn't yet + # exist. + if not hasattr(self._thread_local, "_device_function_stack"): + self._thread_local._device_function_stack = ( + self._graph_device_function_stack[:]) + return self._thread_local._device_function_stack + else: + return self._graph_device_function_stack + + @_device_function_stack.setter + def _device_function_stack(self, device_function_stack): + if self._stack_state_is_thread_local: + self._thread_local._device_function_stack = device_function_stack + else: + self._graph_device_function_stack = device_function_stack + + @property + def _colocation_stack(self): + if self._stack_state_is_thread_local: + # This may be called from a thread where colocation_stack doesn't yet + # exist. + if not hasattr(self._thread_local, "_colocation_stack"): + self._thread_local._colocation_stack = self._graph_colocation_stack[:] + return self._thread_local._colocation_stack + else: + return self._graph_colocation_stack + + @_colocation_stack.setter + def _colocation_stack(self, colocation_stack): + if self._stack_state_is_thread_local: + self._thread_local._colocation_stack = colocation_stack + else: + self._graph_colocation_stack = colocation_stack + + @property + def _control_dependencies_stack(self): + if self._stack_state_is_thread_local: + # This may be called from a thread where control_dependencies_stack + # doesn't yet exist. + if not hasattr(self._thread_local, "_control_dependencies_stack"): + self._thread_local._control_dependencies_stack = ( + self._graph_control_dependencies_stack[:]) + return self._thread_local._control_dependencies_stack + else: + return self._graph_control_dependencies_stack + + @_control_dependencies_stack.setter + def _control_dependencies_stack(self, control_dependencies): + if self._stack_state_is_thread_local: + self._thread_local._control_dependencies_stack = control_dependencies + else: + self._graph_control_dependencies_stack = control_dependencies + # TODO(agarwal): currently device directives in an outer eager scope will not # apply to inner graph mode code. Fix that. @@ -4721,7 +4805,14 @@ def container(container_name): @tf_export("colocate_with") def colocate_with(op, ignore_existing=False): if context.in_graph_mode(): - return get_default_graph().colocate_with(op, ignore_existing) + default_graph = get_default_graph() + if isinstance(op, EagerTensor): + if default_graph.building_function: + op = internal_convert_to_tensor(op) + else: + raise ValueError("Encountered an Eager-defined Tensor during graph " + "construction, but a function was not being built.") + return default_graph.colocate_with(op, ignore_existing) else: if op is not None: return device(op.device) @@ -5537,6 +5628,9 @@ def get_all_collection_keys(): return get_default_graph().get_all_collection_keys() +name_scope_cache = {} + + # Named like a function for backwards compatibility with the # @tf_contextlib.contextmanager version, which was switched to a class to avoid # some object creation overhead. @@ -5596,7 +5690,11 @@ class name_scope(object): # pylint: disable=invalid-name if not self._name: scope_name = "" else: - if self._name[-1] == "/": + cache_key = self._name, self._old_name, self._default_name + if cache_key in name_scope_cache: + self._ctx.scope_name = name_scope_cache[cache_key] + return self._ctx.scope_name + elif self._name[-1] == "/": # A trailing slash breaks out of nested name scopes, indicating a # fully specified scope name, for compatibility with Graph.name_scope. scope_name = self._name @@ -5605,6 +5703,7 @@ class name_scope(object): # pylint: disable=invalid-name scope_name = ( self._old_name + name_with_trailing_slash if self._old_name else name_with_trailing_slash) + name_scope_cache[cache_key] = scope_name self._ctx.scope_name = scope_name return scope_name else: diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index c6deafd89eb1bdc4892a65ba3ab8c7900915390f..a141fe6340c70efde84411db4efb1f80cb0a61c5 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import gc +import threading import weakref from tensorflow.core.framework import attr_value_pb2 @@ -1381,6 +1382,180 @@ class DeviceTest(test_util.TensorFlowTestCase): """, gd) +@test_util.with_c_api +class MultithreadedGraphStateTest(test_util.TensorFlowTestCase): + + class TestThread(threading.Thread): + + def __init__(self, graph, replica_id): + super(MultithreadedGraphStateTest.TestThread, self).__init__() + self._graph = graph + self._replica_id = replica_id + # This thread sets this event when it mutated the graph. The caller can + # wait for that. + self.has_mutated_graph = threading.Event() + # This thread waits for when it should continue. The caller can set this + # event. + self.should_continue = threading.Event() + + def run(self): + # Mutate a graph's stack, then set `has_mutated_graph`, then wait for + # `should_continue`, then add an op to the graph affected by the graph's + # stack. + raise NotImplementedError("must be implemented in descendants") + + def testDeviceFunctionStack(self): + + class DeviceSettingThread(self.TestThread): + + def run(self): + with g.device("/job:worker/replica:{}".format(self._replica_id)): + self.has_mutated_graph.set() + self.should_continue.wait() + self.should_continue.clear() + g.create_op( + "FloatOutput", [], [dtypes.float32], + name="FloatOutput_{}".format(self._replica_id)) + + g = ops.Graph() + # If `switch_to_thread` isn't called, then device placement of the ops + # below is not deterministic. + g.switch_to_thread_local() + threads = [DeviceSettingThread(g, i) for i in range(3)] + for t in threads: + t.start() + t.has_mutated_graph.wait() + t.has_mutated_graph.clear() + for t in threads: + t.should_continue.set() + t.join() + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "FloatOutput_0" op: "FloatOutput" + device: "/job:worker/replica:0" } + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/job:worker/replica:1" } + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/replica:2" } + """, gd) + + def testColocateWith(self): + + class ColocatingThread(self.TestThread): + + def __init__(self, graph, replica_id, op_to_colocate_with): + super(ColocatingThread, self).__init__(graph, replica_id) + self._op_to_colocate_with = op_to_colocate_with + + def run(self): + with g.colocate_with(self._op_to_colocate_with): + self.has_mutated_graph.set() + self.should_continue.wait() + self.should_continue.clear() + g.create_op( + "FloatOutput", [], [dtypes.float32], + name="FloatOutput_{}".format(self._replica_id)) + + g = ops.Graph() + ops_to_colocate_with = [] + for i in range(3): + with g.device("/job:worker/replica:{}".format(i)): + ops_to_colocate_with.append( + g.create_op( + "FloatOutput", [], [dtypes.float32], + name="ColocateWithMe_{}".format(i))) + + # If `switch_to_thread` isn't called, then `device` and `attr` values for + # the ops below are not deterministic. + g.switch_to_thread_local() + threads = [ + ColocatingThread(g, i, ops_to_colocate_with[i]) for i in range(3) + ] + for t in threads: + t.start() + t.has_mutated_graph.wait() + t.has_mutated_graph.clear() + for t in threads: + t.should_continue.set() + t.join() + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "ColocateWithMe_0" op: "FloatOutput" + device: "/job:worker/replica:0" } + node { name: "ColocateWithMe_1" op: "FloatOutput" + device: "/job:worker/replica:1" } + node { name: "ColocateWithMe_2" op: "FloatOutput" + device: "/job:worker/replica:2" } + node { name: "FloatOutput_0" op: "FloatOutput" + device: "/job:worker/replica:0" + attr { key: "_class" + value { list { + s: "loc:@ColocateWithMe_0"}}}} + node { name: "FloatOutput_1" op: "FloatOutput" + device: "/job:worker/replica:1" + attr { key: "_class" + value { list { + s: "loc:@ColocateWithMe_1"}}}} + node { name: "FloatOutput_2" op: "FloatOutput" + device: "/job:worker/replica:2" + attr { key: "_class" + value { list { + s: "loc:@ColocateWithMe_2"}}}} + """, gd) + + def testControlDependencies(self): + + class DependingThread(self.TestThread): + + def __init__(self, graph, replica_id, dependency_op): + super(DependingThread, self).__init__(graph, replica_id) + self._dependency_op = dependency_op + + def run(self): + with g.control_dependencies([self._dependency_op]): + self.has_mutated_graph.set() + self.should_continue.wait() + self.should_continue.clear() + g.create_op( + "FloatOutput", [], [dtypes.float32], + name="FloatOutput_{}".format(self._replica_id)) + + g = ops.Graph() + dependency_ops = [] + for i in range(3): + dependency_ops.append( + g.create_op( + "FloatOutput", [], [dtypes.float32], + name="ColocateWithMe_{}".format(i))) + + # If `switch_to_thread` isn't called, then `input` values for the ops below + # are not deterministic. + g.switch_to_thread_local() + threads = [DependingThread(g, i, dependency_ops[i]) for i in range(3)] + for t in threads: + t.start() + t.has_mutated_graph.wait() + t.has_mutated_graph.clear() + for t in threads: + t.should_continue.set() + t.join() + + gd = g.as_graph_def() + self.assertProtoEqualsVersion(""" + node { name: "ColocateWithMe_0" op: "FloatOutput" } + node { name: "ColocateWithMe_1" op: "FloatOutput" } + node { name: "ColocateWithMe_2" op: "FloatOutput" } + node { name: "FloatOutput_0" op: "FloatOutput" + input: "^ColocateWithMe_0" } + node { name: "FloatOutput_1" op: "FloatOutput" + input: "^ColocateWithMe_1" } + node { name: "FloatOutput_2" op: "FloatOutput" + input: "^ColocateWithMe_2" } + """, gd) + + @test_util.with_c_api class ObjectWithName(object): diff --git a/tensorflow/python/framework/smart_cond.py b/tensorflow/python/framework/smart_cond.py new file mode 100644 index 0000000000000000000000000000000000000000..f97bb01f54bbe2a75072e2bc959ae85b86f79dd0 --- /dev/null +++ b/tensorflow/python/framework/smart_cond.py @@ -0,0 +1,79 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""smart_cond and related utilties.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import control_flow_ops + + +def smart_cond(pred, true_fn=None, false_fn=None, name=None): + """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. + + If `pred` is a bool or has a constant value, we return either `true_fn()` + or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. + + Arguments: + pred: A scalar determining whether to return the result of `true_fn` or + `false_fn`. + true_fn: The callable to be performed if pred is true. + false_fn: The callable to be performed if pred is false. + name: Optional name prefix when using `tf.cond`. + + Returns: + Tensors returned by the call to either `true_fn` or `false_fn`. + + Raises: + TypeError: If `true_fn` or `false_fn` is not callable. + """ + if not callable(true_fn): + raise TypeError("`true_fn` must be callable.") + if not callable(false_fn): + raise TypeError("`false_fn` must be callable.") + + pred_value = smart_constant_value(pred) + if pred_value is not None: + if pred_value: + return true_fn() + else: + return false_fn() + else: + return control_flow_ops.cond(pred, true_fn=true_fn, false_fn=false_fn, + name=name) + + +def smart_constant_value(pred): + """Return the bool value for `pred`, or None if `pred` had a dynamic value. + + Arguments: + pred: A scalar, either a Python bool or tensor. + + Returns: + True or False if `pred` has a constant boolean value, None otherwise. + + Raises: + TypeError: If `pred` is not a Tensor or bool. + """ + if isinstance(pred, bool): + pred_value = pred + elif isinstance(pred, ops.Tensor): + pred_value = tensor_util.constant_value(pred) + else: + raise TypeError("`pred` must be a Tensor or a Python bool.") + return pred_value diff --git a/tensorflow/python/framework/smart_cond_test.py b/tensorflow/python/framework/smart_cond_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b682506da057af9a645f7f71301564268ed3b20d --- /dev/null +++ b/tensorflow/python/framework/smart_cond_test.py @@ -0,0 +1,66 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond +from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + + +@test_util.with_c_api +class SmartCondTest(test_util.TensorFlowTestCase): + + def testSmartCondTrue(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(2) + y = constant_op.constant(5) + z = smart_cond.smart_cond(True, lambda: math_ops.multiply(x, 16), + lambda: math_ops.multiply(y, 5)) + self.assertEqual(z.eval(), 32) + + def testSmartCondFalse(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(4) + y = constant_op.constant(3) + z = smart_cond.smart_cond(False, lambda: math_ops.multiply(x, 16), + lambda: math_ops.multiply(y, 3)) + self.assertEqual(z.eval(), 9) + + def testSmartCondMissingArg1(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + smart_cond.smart_cond(True, false_fn=lambda: x) + + def testSmartCondMissingArg2(self): + with ops.Graph().as_default(): + with session.Session(): + x = constant_op.constant(1) + with self.assertRaises(TypeError): + smart_cond.smart_cond(True, lambda: x) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py index 0e5f696111ae7f74b41f8af21a5190fc2617e51a..27afaa074a6becd5c8b7db94be59e8da1611c13a 100644 --- a/tensorflow/python/framework/tensor_util.py +++ b/tensorflow/python/framework/tensor_util.py @@ -557,7 +557,8 @@ def MakeNdarray(tensor): dtype = tensor_dtype.as_numpy_dtype if tensor.tensor_content: - return np.fromstring(tensor.tensor_content, dtype=dtype).reshape(shape) + return (np.frombuffer(tensor.tensor_content, dtype=dtype).copy() + .reshape(shape)) elif tensor_dtype == dtypes.float16: # the half_val field of the TensorProto stores the binary representation # of the fp16: we need to reinterpret this as a proper float16 diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py index f2de69e159646b4a085645fa1bfef7782e78cd59..bea0ee34fd4900cc9d4d5d52348ba4512368e81f 100644 --- a/tensorflow/python/framework/tensor_util_test.py +++ b/tensorflow/python/framework/tensor_util_test.py @@ -199,6 +199,25 @@ class TensorUtilTest(test.TestCase): dtype=nptype), a) + def testFloatMutateArray(self): + t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32) + a = tensor_util.MakeNdarray(t) + a[0] = 5.0 + self.assertEquals(np.float32, a.dtype) + self.assertAllClose(np.array([5.0, 20.0, 30.0], dtype=np.float32), a) + if sys.byteorder == "big": + self.assertProtoEquals(""" + dtype: DT_FLOAT + tensor_shape { dim { size: 3 } } + tensor_content: "A \000\000A\240\000\000A\360\000\000" + """, t) + else: + self.assertProtoEquals(""" + dtype: DT_FLOAT + tensor_shape { dim { size: 3 } } + tensor_content: "\000\000 A\000\000\240A\000\000\360A" + """, t) + def testHalf(self): t = tensor_util.make_tensor_proto(np.array([10.0, 20.0], dtype=np.float16)) self.assertProtoEquals(""" diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index f7f25990f8a31ff6476fc74a21adfd7b3f57a3c9..7389730d91cf9fd35c861ad85040c79108e5eb77 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -419,6 +419,11 @@ def with_c_api(cls): Returns: cls with new test methods added """ + # If the C API is already enabled, don't do anything. Some tests break if the + # same test is run twice, so this allows us to turn on the C API by default + # without breaking these tests. + if ops._USE_C_API: return cls + for name, value in cls.__dict__.copy().items(): if callable(value) and name.startswith("test"): setattr(cls, name + "WithCApi", enable_c_api(value)) @@ -463,8 +468,7 @@ def assert_no_new_tensors(f): f(self, **kwargs) # Make an effort to clear caches, which would otherwise look like leaked # Tensors. - backprop._last_zero = [None] - backprop._shape_dtype = [None, None] + backprop._zeros_cache.flush() context.get_default_context().scalar_cache().clear() gc.collect() tensors_after = [ @@ -502,6 +506,30 @@ def assert_no_garbage_created(f): previous_garbage = len(gc.garbage) f(self, **kwargs) gc.collect() + if len(gc.garbage) > previous_garbage: + logging.error( + "The decorated test created work for Python's garbage collector, " + "likely due to a reference cycle. New objects in cycle(s):") + for i, obj in enumerate(gc.garbage[previous_garbage:]): + try: + logging.error( + "Object %d of %d" % (i, len(gc.garbage) - previous_garbage)) + def _safe_object_str(obj): + return "<%s %d>" % (obj.__class__.__name__, id(obj)) + logging.error(" Object type: %s" % (_safe_object_str(obj),)) + logging.error(" Referrer types: %s" % ( + ', '.join([_safe_object_str(ref) + for ref in gc.get_referrers(obj)]),)) + logging.error(" Referent types: %s" % ( + ', '.join([_safe_object_str(ref) + for ref in gc.get_referents(obj)]),)) + logging.error(" Object attribute names: %s" % (dir(obj),)) + logging.error(" Object __str__:") + logging.error(obj) + logging.error(" Object __repr__:") + logging.error(repr(obj)) + except Exception: + logging.error("(Exception while printing object)") # This will fail if any garbage has been created, typically because of a # reference cycle. self.assertEqual(previous_garbage, len(gc.garbage)) @@ -560,6 +588,7 @@ def run_in_graph_and_eager_modes(__unused__=None, # This decorator runs the wrapped test twice. # Reset the test environment between runs. self.tearDown() + self._tempdir = None self.setUp() def run_eager_mode(self, **kwargs): @@ -1102,7 +1131,12 @@ class TensorFlowTestCase(googletest.TestCase): np.testing.assert_allclose( a, b, rtol=rtol, atol=atol, err_msg=msg, equal_nan=True) - def _assertAllCloseRecursive(self, a, b, rtol=1e-6, atol=1e-6, path=None, + def _assertAllCloseRecursive(self, + a, + b, + rtol=1e-6, + atol=1e-6, + path=None, msg=None): path = path or [] path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "") @@ -1249,7 +1283,7 @@ class TensorFlowTestCase(googletest.TestCase): a = self._GetNdArray(a) b = self._GetNdArray(b) self.assertEqual(a.shape, b.shape, "Shape mismatch: expected %s, got %s." - " %s" % (a.shape, b.shape, msg)) + " %s" % (a.shape, b.shape, msg)) same = (a == b) if a.dtype == np.float32 or a.dtype == np.float64: @@ -1331,8 +1365,8 @@ class TensorFlowTestCase(googletest.TestCase): raise TypeError("np_array must be a Numpy ndarray or Numpy scalar") if not isinstance(tf_tensor, ops.Tensor): raise TypeError("tf_tensor must be a Tensor") - self.assertAllEqual(np_array.shape, tf_tensor.get_shape().as_list(), - msg=msg) + self.assertAllEqual( + np_array.shape, tf_tensor.get_shape().as_list(), msg=msg) def assertDeviceEqual(self, device1, device2, msg=None): """Asserts that the two given devices are the same. diff --git a/tensorflow/python/grappler/cluster.i b/tensorflow/python/grappler/cluster.i index 8079cb307bb1f5904b71bae891d5ef5f1e749e66..067c8213d4741936e4c28aaedf4f30639b8cdc41 100644 --- a/tensorflow/python/grappler/cluster.i +++ b/tensorflow/python/grappler/cluster.i @@ -206,7 +206,7 @@ static PyObject* TF_ListDevices(GCluster cluster) { return result; } -static std::vector TF_ListAvailableOps() { +static PyObject* TF_ListAvailableOps() { tensorflow::OpRegistry* registry = tensorflow::OpRegistry::Global(); std::vector ops; registry->GetRegisteredOps(&ops); @@ -215,7 +215,14 @@ static std::vector TF_ListAvailableOps() { op_names.push_back(op.name()); } std::sort(op_names.begin(), op_names.end()); - return op_names; + + PyGILState_STATE gstate = PyGILState_Ensure(); + PyObject* result = PyList_New(op_names.size()); + for (int i = 0; i < op_names.size(); ++i) { + PyList_SetItem(result, i, PyString_FromString(op_names[i].c_str())); + } + PyGILState_Release(gstate); + return result; } static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item) { @@ -432,7 +439,7 @@ static GCluster TF_NewVirtualCluster( TF_Status* out_status); static void TF_ShutdownCluster(GCluster cluster); static PyObject* TF_ListDevices(GCluster cluster); -static std::vector TF_ListAvailableOps(); +static PyObject* TF_ListAvailableOps(); static PyObject* TF_GetSupportedDevices(GCluster cluster, GItem item); static float TF_EstimatePerformance(const tensorflow::NamedDevice& device); static PyObject* TF_MeasureCosts( diff --git a/tensorflow/python/grappler/cluster_test.py b/tensorflow/python/grappler/cluster_test.py index 10d515a36458d4025060cf4900251cd493f40795..a3c4c2bbeba7c4ee5d00268c0e475e11a31fa7eb 100644 --- a/tensorflow/python/grappler/cluster_test.py +++ b/tensorflow/python/grappler/cluster_test.py @@ -45,7 +45,7 @@ class ClusterTest(test.TestCase): op_perfs, run_time, step_stats = grappler_cluster.MeasureCosts( grappler_item) self.assertTrue(run_time > 0) - self.assertEqual(len(op_perfs), 7) + self.assertEqual(len(op_perfs), 8) self.assertTrue(step_stats.dev_stats) def testNoDetailedStats(self): @@ -125,14 +125,14 @@ class ClusterTest(test.TestCase): disable_detailed_stats=False, disable_timeline=False) as gcluster: op_perfs, run_time, step_stats = gcluster.MeasureCosts(grappler_item) self.assertTrue(run_time > 0) - self.assertEqual(len(op_perfs), 7) + self.assertEqual(len(op_perfs), 8) self.assertTrue(step_stats.dev_stats) def testAvailableOps(self): with cluster.Provision() as gcluster: op_names = gcluster.ListAvailableOps() - self.assertTrue(b'Add' in op_names) - self.assertTrue(b'MatMul' in op_names) + self.assertTrue('Add' in op_names) + self.assertTrue('MatMul' in op_names) self.assertEqual(op_names, sorted(op_names)) def testSupportDevices(self): diff --git a/tensorflow/python/grappler/controller.py b/tensorflow/python/grappler/controller.py new file mode 100644 index 0000000000000000000000000000000000000000..5677f4f52310dd68dc80c87275b50be95ba86b60 --- /dev/null +++ b/tensorflow/python/grappler/controller.py @@ -0,0 +1,142 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Controller Class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict + + +class Controller(object): + """Controller class.""" + + def __init__(self, item, cluster): + """Controller class initializer. + + Args: + item: The metagraph to place wrapped in a cluster. + cluster: A cluster of devices on which to place the item. + """ + self.item = item + + self._node = {} + for node in item.metagraph.graph_def.node: + self._node[node.name] = node + + self._fanout = defaultdict(lambda: []) + for node in item.metagraph.graph_def.node: + for fanin in self._get_node_fanin(node): + self._fanout[fanin.name].append(node) + + important_op_names = item.IdentifyImportantOps(sort_topologically=True) + + # List of important ops (these are the ops to place) sorted in topological + # order. The order of this collection is deterministic. + self.important_ops = [] + for name in important_op_names: + self.important_ops.append(self._node[name]) + + self.node_properties = item.GetOpProperties() + + self.cluster = cluster + self.devices = cluster.ListDevices() + + self.colocation_constraints = item.GetColocationGroups() + + self.placement_constraints = cluster.GetSupportedDevices(item) + for node_name, dev in self.placement_constraints.items(): + if len(dev) == 1: + # Place the node on the supported device + node = self._node[node_name] + node.device = dev[0] + fanout = self.get_node_fanout(node) + # Update the fanout of the fanin to bypass the node + for fanin in self._get_node_fanin(node): + fanout_of_fanin = self.get_node_fanout(fanin) + fanout_of_fanin += fanout + fanout_of_fanin.remove(node) + # Remove node from the list of important ops since we don't need to + # place the node. + if node in self.important_ops: + self.important_ops.remove(node) + important_op_names.remove(node.name) + + # List of important op names, in non deterministic order. + self.important_op_names = frozenset(important_op_names) + + @property + def input_graph_def(self): + return self.item.metagraph.graph_def + + @property + def num_devices(self): + return len(self.devices) + + def get_node_by_name(self, node_name): + return self._node[node_name] + + def get_node_fanout(self, node): + return self._fanout[node.name] + + def get_placements(self, *args, **kwargs): + """Returns: Two TF ops. + + Args: + *args: "". + **kwargs: "". + + Returns: + y_preds: tensor of size [batch_size, num_ops] + log_probs: python dict of at least two fields: "sample", "target" each + containing a tensor of size [batch_size], corresponding to the log_probs. + """ + raise NotImplementedError + + def eval_placement(self, sess, *args, **kwargs): + """At this time, this method evaluates ONLY ONE placement. + + Args: + sess: a tf.Session() object used to retrieve cached assignment info. + *args: "". + **kwargs: "". + + Returns: + run_time: scalar + """ + raise NotImplementedError + + def export_placement(self, metagraph): + """Annotate the placement onto the specified metagraph. + + Args: + metagraph: the metagraph to annotate with the placement. + """ + for node in metagraph.graph_def.node: + if node.name in self.important_op_names: + node.device = self.get_node_by_name(node.name).device + + # Get the nodes in the immediate fanin of node. + # Beware: this doesn't take into account the nodes that may be skipped + # since placement constraints force their placement. + def _get_node_fanin(self, node): + input_ops = [] + for fanin_name in node.input: + if fanin_name[0] == "^": + fanin_name = fanin_name[1:] + fanin_name = fanin_name.split(":")[0] + input_ops.append(self.get_node_by_name(fanin_name)) + return input_ops diff --git a/tensorflow/python/grappler/cost_analyzer.cc b/tensorflow/python/grappler/cost_analyzer.cc index 88bf900dca6d97773959eb309a4a3c5931fdcb88..b474e19894957d01c7c8978282c547df81a9b2b3 100644 --- a/tensorflow/python/grappler/cost_analyzer.cc +++ b/tensorflow/python/grappler/cost_analyzer.cc @@ -30,11 +30,12 @@ CostAnalyzer::CostAnalyzer(const GrapplerItem& item, Cluster* cluster, analytical_estimator_(cluster, false), suffix_(suffix) {} -Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report) { +Status CostAnalyzer::GenerateReport(std::ostream& os, bool per_node_report, + bool verbose) { GatherCosts(); PreprocessCosts(); AnalyzeCosts(); - PrintAnalysis(os, per_node_report); + PrintAnalysis(os, per_node_report, verbose); return Status::OK(); } @@ -158,7 +159,8 @@ void CostAnalyzer::AnalyzeCosts() { } } -void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report) const { +void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report, + bool verbose) const { os << std::endl; os << std::left << std::setw(50) << "Total time measured in ns (serialized): " << std::right @@ -227,10 +229,55 @@ void CostAnalyzer::PrintAnalysis(std::ostream& os, bool per_node_report) const { os << std::endl; if (per_node_report) { - os << "Below is the per-node report:" << std::endl; - os << op_perf_.DebugString(); + if (verbose) { + os << "Below is the full per-node report:" << std::endl; + os << op_perf_.DebugString(); + } else { + os << "Below is the per-node report summary:" << std::endl; + int width = 35; + int width_narrow = 15; + int width_wide = 20; + os << std::setw(width + 1) << "Op,"; + os << std::setw(width_wide + 1) << "Measured time (ns),"; + os << std::setw(width_wide + 1) << "Compute time (ns),"; + os << std::setw(width_wide + 1) << "Memory time (ns),"; + os << std::setw(width_narrow + 2) << "Compute eff,"; + os << std::setw(width_narrow + 2) << "Memory eff,"; + os << " Inputs" << std::endl; + for (int i = 0; i < op_perf_.op_performance_size(); i++) { + const auto& perf = op_perf_.op_performance(i); + string op_name = perf.op().op(); + os << std::setw(width) << op_name << ","; + os << std::setw(width_wide) << perf.compute_cost() << ","; + os << std::setw(width_wide) << perf.compute_time() << ","; + os << std::setw(width_wide) << perf.memory_time() << ","; + os << std::setw(width_narrow) << std::setprecision(2) + << perf.compute_efficiency() * 100 << "%,"; + os << std::setw(width_narrow) << std::setprecision(2) + << perf.memory_efficiency() * 100 << "%,"; + os << " ["; + for (int j = 0; j < perf.op().inputs_size(); j++) { + const auto& shape = perf.op().inputs(j).shape(); + if (shape.dim_size() > 0) { + os << "("; + std::vector dims; + for (int k = 0; k < shape.dim_size(); k++) { + os << shape.dim(k).size(); + if (k < shape.dim_size() - 1) { + os << ", "; + } + } + os << ")"; + if (j < perf.op().inputs_size() - 1) { + os << ", "; + } + } + } + os << "]" << std::endl; + } + os << std::endl; + } } } - } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/python/grappler/cost_analyzer.h b/tensorflow/python/grappler/cost_analyzer.h index 0e860e0fee9923510292d3cf1a8069435787476f..b5364aa37ab2fbbeb0a33e6764539cca795f2fa6 100644 --- a/tensorflow/python/grappler/cost_analyzer.h +++ b/tensorflow/python/grappler/cost_analyzer.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/core/framework/cost_graph.pb.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/clusters/cluster.h" #include "tensorflow/core/grappler/costs/analytical_cost_estimator.h" #include "tensorflow/core/grappler/costs/cost_estimator.h" @@ -50,7 +51,7 @@ class CostAnalyzer { public: explicit CostAnalyzer(const GrapplerItem& item, Cluster* cluster, const string& suffix); - Status GenerateReport(std::ostream& os, bool per_node_report); + Status GenerateReport(std::ostream& os, bool per_node_report, bool verbose); private: void PredictCosts(CostEstimator* cost_estimator, CostGraphDef* cost_graph, @@ -59,7 +60,8 @@ class CostAnalyzer { void PreprocessCosts(); void AnalyzeCosts(); void SortOpsByTime(std::map ops); - void PrintAnalysis(std::ostream& os, bool per_node_report) const; + void PrintAnalysis(std::ostream& os, bool per_node_report, + bool verbose) const; const GrapplerItem* item_; MeasuringCostEstimator measure_estimator_; diff --git a/tensorflow/python/grappler/cost_analyzer.i b/tensorflow/python/grappler/cost_analyzer.i index 4c0953435ba3fa6423bbc869fcca909d0c2ccb25..8f7fdb47f267bea582e371eb9ea6982b6e9341ad 100644 --- a/tensorflow/python/grappler/cost_analyzer.i +++ b/tensorflow/python/grappler/cost_analyzer.i @@ -44,7 +44,7 @@ limitations under the License. %{ string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report, - GCluster cluster) { + bool verbose, GCluster cluster) { tensorflow::grappler::ItemConfig cfg; cfg.apply_optimizations = false; std::unique_ptr item = @@ -57,11 +57,11 @@ string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_no tensorflow::grappler::CostAnalyzer analyzer(*item, cluster.get(), suffix); std::stringstream os; - analyzer.GenerateReport(os, per_node_report); + analyzer.GenerateReport(os, per_node_report, verbose); return os.str(); } %} string GenerateCostReport(const tensorflow::MetaGraphDef& metagraph, bool per_node_report, - GCluster cluster); + bool verbose, GCluster cluster); diff --git a/tensorflow/python/grappler/cost_analyzer.py b/tensorflow/python/grappler/cost_analyzer.py index a1ff915c61ba14d9a899d7f6c9a2c49855969b00..6a4690e91ba981706eed0d9fdfae2e64359d0416 100644 --- a/tensorflow/python/grappler/cost_analyzer.py +++ b/tensorflow/python/grappler/cost_analyzer.py @@ -24,7 +24,10 @@ from tensorflow.python.grappler import cluster as gcluster from tensorflow.python.grappler import item as gitem -def GenerateCostReport(metagraph, per_node_report=False, cluster=None): +def GenerateCostReport(metagraph, + per_node_report=False, + verbose=False, + cluster=None): """Analyze the cost of each TensorFlow op and node in the provided metagraph. Args: @@ -32,6 +35,7 @@ def GenerateCostReport(metagraph, per_node_report=False, cluster=None): per_node_report: by default the report contains stats aggregated on a per op type basis, setting per_node_report to True adds results for each individual node to the report. + verbose: Prints out the entire operation proto instead of a summary table. cluster: Analyze the costs using the specified cluster, or the local machine if no cluster was specified. @@ -42,8 +46,9 @@ def GenerateCostReport(metagraph, per_node_report=False, cluster=None): cluster = gcluster.Cluster(disable_detailed_stats=False) with errors.raise_exception_on_not_ok_status(): - ret_from_swig = tf_wrap.GenerateCostReport( - metagraph.SerializeToString(), per_node_report, cluster.tf_cluster) + ret_from_swig = tf_wrap.GenerateCostReport(metagraph.SerializeToString(), + per_node_report, verbose, + cluster.tf_cluster) return ret_from_swig diff --git a/tensorflow/python/grappler/cost_analyzer_test.py b/tensorflow/python/grappler/cost_analyzer_test.py index 511908c79ce47d6849bf97d11bc42f2f1bb13f18..b8225b81a52f1a2ee10663544d54f1c9bd7ee785 100644 --- a/tensorflow/python/grappler/cost_analyzer_test.py +++ b/tensorflow/python/grappler/cost_analyzer_test.py @@ -48,7 +48,7 @@ class CostAnalysisTest(test.TestCase): train_op.append(d) mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph()) - report = cost_analyzer.GenerateCostReport(mg) + report = cost_analyzer.GenerateCostReport(mg, per_node_report=True) # Check the report headers self.assertTrue(b"Total time measured in ns (serialized):" in report) @@ -57,6 +57,26 @@ class CostAnalysisTest(test.TestCase): self.assertTrue(b"Total time analytical in ns (lower bound):" in report) self.assertTrue(b"Overall efficiency (analytical upper/actual):" in report) self.assertTrue(b"Overall efficiency (analytical lower/actual):" in report) + self.assertTrue(b"Below is the per-node report summary:" in report) + + # Also print the report to make it easier to debug + print("{}".format(report)) + + def testVerbose(self): + """Make sure the full report is generated with verbose=True.""" + a = constant_op.constant(10, name="a") + b = constant_op.constant(20, name="b") + c = math_ops.add_n([a, b], name="c") + d = math_ops.add_n([b, c], name="d") + train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP) + train_op.append(d) + mg = meta_graph.create_meta_graph_def(graph=ops.get_default_graph()) + + report = cost_analyzer.GenerateCostReport( + mg, per_node_report=True, verbose=True) + + # Check the report headers + self.assertTrue(b"Below is the full per-node report:" in report) # Also print the report to make it easier to debug print("{}".format(report)) diff --git a/tensorflow/python/grappler/cost_analyzer_tool.py b/tensorflow/python/grappler/cost_analyzer_tool.py index 51b77b471b09d59f1a63b5cc3b736a8f2462351d..0853db252406966cec36b63efafec9ec755c7e87 100644 --- a/tensorflow/python/grappler/cost_analyzer_tool.py +++ b/tensorflow/python/grappler/cost_analyzer_tool.py @@ -35,11 +35,20 @@ from tensorflow.python.platform import gfile from tensorflow.python.training import saver -def main(_): +def get_metagraph(): + """Constructs and returns a MetaGraphDef from the input file.""" if FLAGS.metagraphdef: with gfile.GFile(FLAGS.metagraphdef) as meta_file: metagraph = meta_graph_pb2.MetaGraphDef() - metagraph.ParseFromString(meta_file.read()) + if FLAGS.metagraphdef.endswith(".pbtxt"): + text_format.Merge(meta_file.read(), metagraph) + else: + metagraph.ParseFromString(meta_file.read()) + if FLAGS.fetch is not None: + fetch_collection = meta_graph_pb2.CollectionDef() + for fetch in FLAGS.fetch.split(","): + fetch_collection.node_list.value.append(fetch) + metagraph.collection_def["train_op"].CopyFrom(fetch_collection) else: with gfile.GFile(FLAGS.graphdef) as graph_file: graph_def = graph_pb2.GraphDef() @@ -49,21 +58,28 @@ def main(_): graph_def.ParseFromString(graph_file.read()) importer.import_graph_def(graph_def, name="") graph = ops.get_default_graph() - fetch = graph.get_operation_by_name(FLAGS.fetch) - graph.add_to_collection("train_op", fetch) + for fetch in FLAGS.fetch.split(","): + fetch_op = graph.get_operation_by_name(fetch) + graph.add_to_collection("train_op", fetch_op) metagraph = saver.export_meta_graph( graph_def=graph.as_graph_def(), graph=graph) + return metagraph + +def main(_): + metagraph = get_metagraph() rewriter_config = rewriter_config_pb2.RewriterConfig() if FLAGS.rewriter_config is not None: text_format.Merge(FLAGS.rewriter_config, rewriter_config) optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, metagraph) metagraph.graph_def.CopyFrom(optimized_graph) - report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report) - print(report) - report = cost_analyzer.GenerateMemoryReport(metagraph) + report = cost_analyzer.GenerateCostReport(metagraph, FLAGS.per_node_report, + FLAGS.verbose) print(report) + if FLAGS.memory_report: + report = cost_analyzer.GenerateMemoryReport(metagraph) + print(report) if __name__ == "__main__": @@ -78,16 +94,11 @@ if __name__ == "__main__": type=str, default=None, help="Input .pb GraphDef file path.") - # Consider making flag fetch work together with flag metagraphdef. As some - # MetaGraphDef files don't have collection train_op. parser.add_argument( "--fetch", type=str, default=None, - help= - "The name of the fetch node. This flag is ignored if flag " - "metagraphdef is used." - ) + help="The names of the fetch node delimited by comma.") parser.add_argument( "--rewriter_config", type=str, @@ -103,5 +114,13 @@ if __name__ == "__main__": help="Generate per-node report. By default the report contains stats " "aggregated on a per op type basis, per_node_report adds results " "for each individual node to the report.") + parser.add_argument( + "--memory_report", + action="store_true", + help="Generate memory usage report.") + parser.add_argument( + "--verbose", + action="store_true", + help="Generate verbose reports. By default, succinct reports are used.") FLAGS, unparsed = parser.parse_known_args() app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/python/grappler/graph_placer.py b/tensorflow/python/grappler/graph_placer.py new file mode 100644 index 0000000000000000000000000000000000000000..1cd51df4d962583555e08ae973ab43d15ba01997 --- /dev/null +++ b/tensorflow/python/grappler/graph_placer.py @@ -0,0 +1,120 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Graph Placer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.grappler import cluster as gcluster +from tensorflow.python.grappler import hierarchical_controller +from tensorflow.python.grappler import item as gitem +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.training import training + + +def PlaceGraph(metagraph, + cluster=None, + allotted_time=3600, + hparams=None, + verbose=False): + """Place the provided metagraph. + + Args: + metagraph: the metagraph to place. + cluster: an optional set of hardware resource to optimize the placement for. + If none is specified, we'll optimize the placement for the hardware + available on the local machine. + allotted_time: the maximum amount to time in seconds to spend optimizing + the placement. + hparams: hyperparameters used to fine tune the placer. + verbose: prints debug information if True. + + Returns: + The placed metagraph. + """ + if cluster is None: + cluster = gcluster.Cluster() + + # Optimize the metagraph to speedup the placement + rewriter_config = rewriter_config_pb2.RewriterConfig() + rewriter_config.optimizers.append("pruning") + rewriter_config.optimizers.append("constfold") + rewriter_config.optimizers.append("arithmetic") + rewriter_config.optimizers.append("dependency") + rewriter_config.optimizers.append("pruning") + optimized_graph = tf_optimizer.OptimizeGraph( + rewriter_config, metagraph, verbose=verbose, cluster=cluster) + optimized_metagraph = meta_graph_pb2.MetaGraphDef() + optimized_metagraph.CopyFrom(metagraph) + optimized_metagraph.graph_def.CopyFrom(optimized_graph) + + item = gitem.Item(optimized_metagraph) + + # Measure the runtime achievable with the original placement. + try: + _, original_run_time, _ = cluster.MeasureCosts(item) + if verbose: + print("Runtime for original placement: " + str(original_run_time)) + except errors.OpError as e: + if verbose: + print("Original placement isn't feasible: " + str(e)) + original_run_time = hparams.failing_signal + + if hparams is None: + hparams = hierarchical_controller.hierarchical_controller_hparams() + # We run with a single child + hparams.num_children = 1 + + with tf_ops.Graph().as_default(): + # Place all the nodes of the controller on the CPU. We don't want them to + # fight for accelerator memory with the model to optimize. + with tf_ops.device("/device:CPU:0"): + model = hierarchical_controller.HierarchicalController( + hparams, item, cluster) + ops = model.build_controller() + session_creator = training.ChiefSessionCreator() + with training.MonitoredSession(session_creator=session_creator) as sess: + start_time = time.time() + current_time = start_time + while current_time - start_time < allotted_time: + grouping_actions = model.generate_grouping(sess) + input_to_seq2seq = model.create_group_embeddings( + grouping_actions, verbose=verbose) + model.generate_placement(input_to_seq2seq, sess) + try: + run_time = model.eval_placement( + sess, + verbose=verbose) + except errors.OpError as e: + if verbose: + print("Failed to run graph:" + str(e)) + run_time = hparams.failing_signal + updated = model.update_reward(sess, run_time, verbose=verbose) + if updated and run_time < original_run_time: + if verbose: + print("Found better placement, with runtime " + str(run_time)) + model.export_placement(metagraph) + + model.process_reward(sess) + + current_time = time.time() + + return metagraph diff --git a/tensorflow/python/grappler/graph_placer_test.py b/tensorflow/python/grappler/graph_placer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9eabe3cd5437022eb3b98010d0f384cc9f6bac2a --- /dev/null +++ b/tensorflow/python/grappler/graph_placer_test.py @@ -0,0 +1,140 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests the graph placer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.core.protobuf import device_properties_pb2 +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import meta_graph +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.grappler import cluster +from tensorflow.python.grappler import graph_placer +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class GraphPlacerTest(test.TestCase): + + @staticmethod + def _buildMnist(batch_size=128, + input_size=256, + num_classes=1024, + num_layers=10, + hidden_size=256, + name='mnist'): + g = tf_ops.get_default_graph() + with g.as_default(): + ops = {} + x = random_ops.random_uniform( + [batch_size, input_size], -0.1, 0.1, dtype=dtypes.float32) + for layer_id in range(num_layers): + with variable_scope.variable_scope('layer_{}'.format(layer_id)): + a = input_size if layer_id == 0 else hidden_size + b = hidden_size if layer_id < num_layers - 1 else num_classes + w = variable_scope.get_variable('w', [a, b]) + x = math_ops.matmul(x, w) + x = nn_ops.relu(x) + ops['y_preds'] = math_ops.argmax(x, axis=1) + + train_op = g.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP) + train_op.append(ops['y_preds']) + return g + + @staticmethod + def _buildCluster(num_cpus=1, num_gpus=1): + devices = [] + if num_gpus > 0: + device_properties = device_properties_pb2.DeviceProperties( + type='GPU', + vendor='NVidia', + model='GeForce GTX TITAN X', + frequency=1076, + num_cores=24, + environment={'architecture': '5.2', + 'cuda': '8000', + 'cudnn': '6021'}, + num_registers=65536, + l1_cache_size=24576, + l2_cache_size=3145728, + shared_memory_size_per_multiprocessor=98304, + memory_size=12783648768, + bandwidth=336480000) + for i in range(num_gpus): + devices.append( + device_properties_pb2.NamedDevice( + properties=device_properties, name='/GPU:' + str(i))) + + assert num_cpus > 0 + device_properties = device_properties_pb2.DeviceProperties( + type='CPU', + frequency=2000, + num_cores=4, + l1_cache_size=32768, + l2_cache_size=262144, + l3_cache_size=12582912) + for i in range(num_cpus): + devices.append( + device_properties_pb2.NamedDevice( + properties=device_properties, name='/CPU:' + str(i))) + + return cluster.Cluster(devices=devices) + + def testBasic(self): + """Place a trivial graph.""" + a = constant_op.constant(10, name='a') + b = constant_op.constant(20, name='b') + c = math_ops.add_n([a, b], name='c') + d = math_ops.add_n([b, c], name='d') + train_op = tf_ops.get_collection_ref(tf_ops.GraphKeys.TRAIN_OP) + train_op.append(d) + mg = meta_graph.create_meta_graph_def(graph=tf_ops.get_default_graph()) + + gcluster = cluster.Cluster() + placed_mg = graph_placer.PlaceGraph(mg, allotted_time=15, cluster=gcluster) + + self.assertEqual(4, len(placed_mg.graph_def.node)) + self.assertItemsEqual([node.name for node in placed_mg.graph_def.node], + [node.name for node in mg.graph_def.node]) + + available_devices = [device.name for device in gcluster.ListDevices()] + for node in placed_mg.graph_def.node: + # The constant nodes are optimized away before the placer is run, and + # therefore won't be placed. + self.assertTrue(not node.device or node.device in available_devices) + + def testMNIST(self): + graph = GraphPlacerTest._buildMnist() + mg = meta_graph.create_meta_graph_def(graph=graph) + gcluster = GraphPlacerTest._buildCluster(num_gpus=1) + # Spend 15 seconds trying to optimize the placement of the model. This + # should give us enough time to exercise the code, but not enough to find + # a good placement, so we'll just check for legality. + placed_mg = graph_placer.PlaceGraph(mg, allotted_time=15, cluster=gcluster) + self.assertEqual(len(placed_mg.graph_def.node), len(mg.graph_def.node)) + self.assertItemsEqual([node.name for node in placed_mg.graph_def.node], + [node.name for node in mg.graph_def.node]) + available_devices = [device.name for device in gcluster.ListDevices()] + for node in placed_mg.graph_def.node: + self.assertTrue(not node.device or node.device in available_devices) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/grappler/hierarchical_controller.py b/tensorflow/python/grappler/hierarchical_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..b06fb3c6d0666659031863b90212e9456d044c14 --- /dev/null +++ b/tensorflow/python/grappler/hierarchical_controller.py @@ -0,0 +1,1098 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""HierarchicalController Class. + +The HierarchicalController encompasses the entire lifecycle of training the +device placement policy, including generating op embeddings, getting groups for +each op, placing those groups and running the predicted placements. + +Different assignment models can inherit from this class. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import numpy as np +import six +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.grappler.controller import Controller +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import embedding_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import tensor_array_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import adam +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import training_util + + +class PlacerParams(object): + """Class to hold a set of placement parameters as name-value pairs. + + A typical usage is as follows: + + ```python + # Create a PlacerParams object specifying names and values of the model + # parameters: + params = PlacerParams(hidden_size=128, decay_steps=50) + + # The parameters are available as attributes of the PlacerParams object: + hparams.hidden_size ==> 128 + hparams.decay_steps ==> 50 + ``` + + """ + + def __init__(self, **kwargs): + """Create an instance of `PlacerParams` from keyword arguments. + + The keyword arguments specify name-values pairs for the parameters. + The parameter types are inferred from the type of the values passed. + + The parameter names are added as attributes of `PlacerParams` object, + and they can be accessed directly with the dot notation `params._name_`. + + Example: + + ```python + # Define 1 parameter: 'hidden_size' + params = PlacerParams(hidden_size=128) + params.hidden_size ==> 128 + ``` + + Args: + **kwargs: Key-value pairs where the key is the parameter name and + the value is the value for the parameter. + """ + for name, value in six.iteritems(kwargs): + self.add_param(name, value) + + def add_param(self, name, value): + """Adds {name, value} pair to hyperparameters. + + Args: + name: Name of the hyperparameter. + value: Value of the hyperparameter. Can be one of the following types: + int, float, string, int list, float list, or string list. + + Raises: + ValueError: if one of the arguments is invalid. + """ + # Keys in kwargs are unique, but 'name' could be the name of a pre-existing + # attribute of this object. In that case we refuse to use it as a + # parameter name. + if getattr(self, name, None) is not None: + raise ValueError("Parameter name is reserved: %s" % name) + setattr(self, name, value) + + +def hierarchical_controller_hparams(): + """Hyperparameters for hierarchical planner.""" + return PlacerParams( + hidden_size=512, + forget_bias_init=1.0, + temperature=1.0, + logits_std_noise=0.5, + stop_noise_step=750, + decay_steps=50, + max_num_outputs=5, + max_output_size=5, + tanh_constant=1.0, + adj_embed_dim=20, + grouping_hidden_size=64, + num_groups=None, + bi_lstm=True, + failing_signal=100, + stop_sampling=500, + start_with_failing_signal=True, + always_update_baseline=False, + bl_dec=0.9, + grad_bound=1.0, + lr=0.1, + lr_dec=0.95, + start_decay_step=400, + optimizer_type="adam", + stop_updating_after_steps=1000, + name="hierarchical_controller", + keep_prob=1.0, + reward_function="sqrt", + seed=1234, + # distributed training params + num_children=1) + + +class HierarchicalController(Controller): + """HierarchicalController class.""" + + def __init__(self, hparams, item, cluster, controller_id=0): + """HierarchicalController class initializer. + + Args: + hparams: All hyper-parameters. + item: The metagraph to place. + cluster: The cluster of hardware devices to optimize for. + controller_id: the id of the controller in a multi-controller setup. + """ + super(HierarchicalController, self).__init__(item, cluster) + self.ctrl_id = controller_id + self.hparams = hparams + + if self.hparams.num_groups is None: + self.num_groups = min(256, 20 * self.num_devices) + else: + self.num_groups = self.hparams.num_groups + + # creates self.op_embeddings and self.type_dict + self.create_op_embeddings(verbose=False) + # TODO(azalia) clean up embedding/group_embedding_size names + self.group_emb_size = ( + 2 * self.num_groups + len(self.type_dict) + + self.hparams.max_num_outputs * self.hparams.max_output_size) + self.embedding_size = self.group_emb_size + self.initializer = init_ops.glorot_uniform_initializer( + seed=self.hparams.seed) + + with variable_scope.variable_scope( + self.hparams.name, + initializer=self.initializer, + reuse=variable_scope.AUTO_REUSE): + # define parameters of feedforward + variable_scope.get_variable("w_grouping_ff", [ + 1 + self.hparams.max_num_outputs * self.hparams.max_output_size + + self.hparams.adj_embed_dim, self.hparams.grouping_hidden_size + ]) + variable_scope.get_variable( + "w_grouping_softmax", + [self.hparams.grouping_hidden_size, self.num_groups]) + if self.hparams.bi_lstm: + variable_scope.get_variable("encoder_lstm_forward", [ + self.embedding_size + self.hparams.hidden_size / 2, + 2 * self.hparams.hidden_size + ]) + variable_scope.get_variable("encoder_lstm_backward", [ + self.embedding_size + self.hparams.hidden_size / 2, + 2 * self.hparams.hidden_size + ]) + variable_scope.get_variable( + "device_embeddings", [self.num_devices, self.hparams.hidden_size]) + variable_scope.get_variable( + "decoder_lstm", + [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size]) + variable_scope.get_variable( + "device_softmax", [2 * self.hparams.hidden_size, self.num_devices]) + variable_scope.get_variable("device_go_embedding", + [1, self.hparams.hidden_size]) + variable_scope.get_variable( + "encoder_forget_bias", + shape=1, + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + self.hparams.forget_bias_init)) + variable_scope.get_variable( + "decoder_forget_bias", + shape=1, + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + self.hparams.forget_bias_init)) + variable_scope.get_variable( + "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size]) + variable_scope.get_variable( + "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size]) + variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1]) + + else: + variable_scope.get_variable("encoder_lstm", [ + self.embedding_size + self.hparams.hidden_size, + 4 * self.hparams.hidden_size + ]) + variable_scope.get_variable( + "device_embeddings", [self.num_devices, self.hparams.hidden_size]) + variable_scope.get_variable( + "decoder_lstm", + [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size]) + variable_scope.get_variable( + "device_softmax", [2 * self.hparams.hidden_size, self.num_devices]) + variable_scope.get_variable("device_go_embedding", + [1, self.hparams.hidden_size]) + variable_scope.get_variable( + "encoder_forget_bias", + shape=1, + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + self.hparams.forget_bias_init)) + variable_scope.get_variable( + "decoder_forget_bias", + shape=1, + dtype=dtypes.float32, + initializer=init_ops.constant_initializer( + self.hparams.forget_bias_init)) + variable_scope.get_variable( + "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size]) + variable_scope.get_variable( + "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size]) + variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1]) + seq2seq_input_layer = array_ops.placeholder_with_default( + array_ops.zeros([1, self.num_groups, self.group_emb_size], + dtypes.float32), + shape=(1, self.num_groups, self.group_emb_size)) + self.seq2seq_input_layer = seq2seq_input_layer + + def compute_reward(self, run_time): + if self.hparams.reward_function == "id": + reward = run_time + elif self.hparams.reward_function == "sqrt": + reward = math.sqrt(run_time) + elif self.hparams.reward_function == "log": + reward = math.log1p(run_time) + else: + raise NotImplementedError( + "Unrecognized reward function '%s', consider your " + "--reward_function flag value." % self.hparams.reward_function) + return reward + + def build_controller(self): + """RL optimization interface. + + Returns: + ops: A dictionary holding handles of the model used for training. + """ + + self._global_step = training_util.get_or_create_global_step() + ops = {} + ops["loss"] = 0 + + failing_signal = self.compute_reward(self.hparams.failing_signal) + + ctr = {} + + with tf_ops.name_scope("controller_{}".format(self.ctrl_id)): + with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): + ctr["reward"] = {"value": [], "ph": [], "update": []} + ctr["ready"] = {"value": [], "ph": [], "update": []} + ctr["best_reward"] = {"value": [], "update": []} + for i in range(self.hparams.num_children): + reward_value = variable_scope.get_local_variable( + "reward_{}".format(i), + initializer=0.0, + dtype=dtypes.float32, + trainable=False) + reward_ph = array_ops.placeholder( + dtypes.float32, shape=(), name="reward_ph_{}".format(i)) + reward_update = state_ops.assign( + reward_value, reward_ph, use_locking=True) + ctr["reward"]["value"].append(reward_value) + ctr["reward"]["ph"].append(reward_ph) + ctr["reward"]["update"].append(reward_update) + best_reward = variable_scope.get_local_variable( + "best_reward_{}".format(i), + initializer=failing_signal, + dtype=dtypes.float32, + trainable=False) + ctr["best_reward"]["value"].append(best_reward) + ctr["best_reward"]["update"].append( + state_ops.assign(best_reward, + math_ops.minimum(best_reward, reward_update))) + + ready_value = variable_scope.get_local_variable( + "ready_{}".format(i), + initializer=True, + dtype=dtypes.bool, + trainable=False) + ready_ph = array_ops.placeholder( + dtypes.bool, shape=(), name="ready_ph_{}".format(i)) + ready_update = state_ops.assign( + ready_value, ready_ph, use_locking=True) + ctr["ready"]["value"].append(ready_value) + ctr["ready"]["ph"].append(ready_ph) + ctr["ready"]["update"].append(ready_update) + + ctr["grouping_y_preds"], ctr["grouping_log_probs"] = self.get_groupings() + summary.histogram( + "grouping_actions", + array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0], + [1, array_ops.shape(self.op_embeddings)[0]])) + + with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): + ctr["baseline"] = variable_scope.get_local_variable( + "baseline", + initializer=failing_signal + if self.hparams.start_with_failing_signal else 0.0, + dtype=dtypes.float32, + trainable=False) + + new_baseline = self.hparams.bl_dec * ctr["baseline"] + ( + 1 - self.hparams.bl_dec) * math_ops.reduce_mean( + ctr["reward"]["value"]) + if not self.hparams.always_update_baseline: + baseline_mask = math_ops.less(ctr["reward"]["value"], failing_signal) + selected_reward = array_ops.boolean_mask(ctr["reward"]["value"], + baseline_mask) + selected_baseline = control_flow_ops.cond( + math_ops.reduce_any(baseline_mask), + lambda: math_ops.reduce_mean(selected_reward), + lambda: constant_op.constant(0, dtype=dtypes.float32)) + ctr["pos_reward"] = selected_baseline + pos_ = math_ops.less( + constant_op.constant(0, dtype=dtypes.float32), selected_baseline) + selected_baseline = self.hparams.bl_dec * ctr["baseline"] + ( + 1 - self.hparams.bl_dec) * selected_baseline + selected_baseline = control_flow_ops.cond( + pos_, lambda: selected_baseline, lambda: ctr["baseline"]) + new_baseline = control_flow_ops.cond( + math_ops.less(self.global_step, + self.hparams.stop_updating_after_steps), + lambda: new_baseline, lambda: selected_baseline) + ctr["baseline_update"] = state_ops.assign( + ctr["baseline"], new_baseline, use_locking=True) + + ctr["y_preds"], ctr["log_probs"] = self.get_placements() + summary.histogram("actions", ctr["y_preds"]["sample"]) + mask = math_ops.less(ctr["reward"]["value"], failing_signal) + ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"] + ctr["loss"] *= ( + ctr["log_probs"]["sample"] + ctr["grouping_log_probs"]["sample"]) + + selected_loss = array_ops.boolean_mask(ctr["loss"], mask) + selected_loss = control_flow_ops.cond( + math_ops.reduce_any(mask), + lambda: math_ops.reduce_mean(-selected_loss), + lambda: constant_op.constant(0, dtype=dtypes.float32)) + + ctr["loss"] = control_flow_ops.cond( + math_ops.less(self.global_step, + self.hparams.stop_updating_after_steps), + lambda: math_ops.reduce_mean(-ctr["loss"]), lambda: selected_loss) + + ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"]) + summary.scalar("loss", ctr["loss"]) + summary.scalar("avg_reward", ctr["reward_s"]) + summary.scalar("best_reward_so_far", best_reward) + summary.scalar( + "advantage", + math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"])) + + with variable_scope.variable_scope( + "optimizer", reuse=variable_scope.AUTO_REUSE): + (ctr["train_op"], ctr["lr"], ctr["grad_norm"], + ctr["grad_norms"]) = self._get_train_ops( + ctr["loss"], + tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES), + self.global_step, + grad_bound=self.hparams.grad_bound, + lr_init=self.hparams.lr, + lr_dec=self.hparams.lr_dec, + start_decay_step=self.hparams.start_decay_step, + decay_steps=self.hparams.decay_steps, + optimizer_type=self.hparams.optimizer_type) + + summary.scalar("gradnorm", ctr["grad_norm"]) + summary.scalar("lr", ctr["lr"]) + ctr["summary"] = summary.merge_all() + ops["controller"] = ctr + + self.ops = ops + return ops + + @property + def global_step(self): + return self._global_step + + def create_op_embeddings(self, verbose=False): + if verbose: + print("process input graph for op embeddings") + self.num_ops = len(self.important_ops) + # topological sort of important nodes + topo_order = [op.name for op in self.important_ops] + + # create index to name for topologicaly sorted important nodes + name_to_topo_order_index = {} + for idx, x in enumerate(topo_order): + name_to_topo_order_index[x] = idx + self.name_to_topo_order_index = name_to_topo_order_index + + # create adj matrix + adj_dict = {} + for idx, op in enumerate(self.important_ops): + for output_op in self.get_node_fanout(op): + output_op_name = output_op.name + if output_op_name in self.important_op_names: + if name_to_topo_order_index[op.name] not in adj_dict: + adj_dict[name_to_topo_order_index[op.name]] = [] + adj_dict[name_to_topo_order_index[op.name]].extend( + [name_to_topo_order_index[output_op_name], 1]) + if output_op_name not in adj_dict: + adj_dict[name_to_topo_order_index[output_op_name]] = [] + adj_dict[name_to_topo_order_index[output_op_name]].extend( + [name_to_topo_order_index[op.name], -1]) + + # get op_type op_output_shape, and adj info + output_embed_dim = (self.hparams.max_num_outputs * + self.hparams.max_output_size) + + # TODO(bsteiner): don't filter based on used ops so that we can generalize + # to models that use other types of ops. + used_ops = set() + for node in self.important_ops: + op_type = str(node.op) + used_ops.add(op_type) + + self.type_dict = {} + for op_type in self.cluster.ListAvailableOps(): + if op_type in used_ops: + self.type_dict[op_type] = len(self.type_dict) + + op_types = np.zeros([self.num_ops], dtype=np.int32) + op_output_shapes = np.full( + [self.num_ops, output_embed_dim], -1.0, dtype=np.float32) + for idx, node in enumerate(self.important_ops): + op_types[idx] = self.type_dict[node.op] + # output shape + op_name = node.name + for i, output_prop in enumerate(self.node_properties[op_name]): + if output_prop.shape.__str__() == "": + continue + shape = output_prop.shape + for j, dim in enumerate(shape.dim): + if dim.size >= 0: + if i * self.hparams.max_output_size + j >= output_embed_dim: + break + op_output_shapes[idx, + i * self.hparams.max_output_size + j] = dim.size + # adj for padding + op_adj = np.full( + [self.num_ops, self.hparams.adj_embed_dim], 0, dtype=np.float32) + for idx in adj_dict: + neighbors = adj_dict[int(idx)] + min_dim = min(self.hparams.adj_embed_dim, len(neighbors)) + padding_size = self.hparams.adj_embed_dim - min_dim + neighbors = neighbors[:min_dim] + [0] * padding_size + op_adj[int(idx)] = neighbors + + # op_embedding starts here + op_embeddings = np.zeros( + [ + self.num_ops, + 1 + self.hparams.max_num_outputs * self.hparams.max_output_size + + self.hparams.adj_embed_dim + ], + dtype=np.float32) + for idx, op_name in enumerate(topo_order): + op_embeddings[idx] = np.concatenate( + (np.array([op_types[idx]]), op_output_shapes[idx], op_adj[int(idx)])) + self.op_embeddings = constant_op.constant( + op_embeddings, dtype=dtypes.float32) + if verbose: + print("num_ops = {}".format(self.num_ops)) + print("num_types = {}".format(len(self.type_dict))) + + def get_groupings(self, *args, **kwargs): + num_children = self.hparams.num_children + with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): + grouping_actions_cache = variable_scope.get_local_variable( + "grouping_actions_cache", + initializer=init_ops.zeros_initializer, + dtype=dtypes.int32, + shape=[num_children, self.num_ops], + trainable=False) + input_layer = self.op_embeddings + input_layer = array_ops.expand_dims(input_layer, 0) + feed_ff_input_layer = array_ops.tile(input_layer, [num_children, 1, 1]) + grouping_actions, grouping_log_probs = {}, {} + grouping_actions["sample"], grouping_log_probs[ + "sample"] = self.make_grouping_predictions(feed_ff_input_layer) + + grouping_actions["sample"] = state_ops.assign(grouping_actions_cache, + grouping_actions["sample"]) + self.grouping_actions_cache = grouping_actions_cache + + return grouping_actions, grouping_log_probs + + def make_grouping_predictions(self, input_layer, reuse=None): + """model that predicts grouping (grouping_actions). + + Args: + input_layer: group_input_layer + reuse: reuse + + Returns: + grouping_actions: actions + grouping_log_probs: log probabilities corresponding to actions + """ + with variable_scope.variable_scope(self.hparams.name, reuse=True): + # input_layer: tensor of size [1, num_ops, hidden_size] + w_grouping_ff = variable_scope.get_variable("w_grouping_ff") + w_grouping_softmax = variable_scope.get_variable("w_grouping_softmax") + + batch_size = array_ops.shape(input_layer)[0] + embedding_dim = array_ops.shape(input_layer)[2] + + reshaped = array_ops.reshape(input_layer, + [batch_size * self.num_ops, embedding_dim]) + ff_output = math_ops.matmul(reshaped, w_grouping_ff) + logits = math_ops.matmul(ff_output, w_grouping_softmax) + if self.hparams.logits_std_noise > 0: + num_in_logits = math_ops.cast( + array_ops.size(logits), dtype=dtypes.float32) + avg_norm = math_ops.divide( + linalg_ops.norm(logits), math_ops.sqrt(num_in_logits)) + logits_noise = random_ops.random_normal( + array_ops.shape(logits), + stddev=self.hparams.logits_std_noise * avg_norm) + logits = control_flow_ops.cond( + self.global_step > self.hparams.stop_noise_step, lambda: logits, + lambda: logits + logits_noise) + logits = array_ops.reshape(logits, + [batch_size * self.num_ops, self.num_groups]) + actions = random_ops.multinomial(logits, 1, seed=self.hparams.seed) + actions = math_ops.to_int32(actions) + actions = array_ops.reshape(actions, [batch_size, self.num_ops]) + action_label = array_ops.reshape(actions, [-1]) + log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=action_label) + log_probs = array_ops.reshape(log_probs, [batch_size, -1]) + log_probs = math_ops.reduce_sum(log_probs, 1) + grouping_actions = actions + grouping_log_probs = log_probs + return grouping_actions, grouping_log_probs + + def create_group_embeddings(self, grouping_actions, verbose=False): + """Approximating the blocks of a TF graph from a graph_def. + + Args: + grouping_actions: grouping predictions + verbose: print stuffs. + + Returns: + groups: list of groups. + """ + if verbose: + print("Processing input_graph") + + # TODO(azalia): Build inter-adjacencies dag matrix. + # record dag_matrix + dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32) + for op in self.important_ops: + topo_op_index = self.name_to_topo_order_index[op.name] + # TODO(agoldie) child_id + group_index = grouping_actions[0][topo_op_index] + for output_op in self.get_node_fanout(op): + if output_op.name not in self.important_op_names: + continue + output_group_index = grouping_actions[0][self.name_to_topo_order_index[ + output_op.name]] + dag_matrix[group_index, output_group_index] += 1.0 + num_connections = np.sum(dag_matrix) + num_intra_group_connections = dag_matrix.trace() + num_inter_group_connections = num_connections - num_intra_group_connections + if verbose: + print("grouping evaluation metric") + print(("num_connections={} num_intra_group_connections={} " + "num_inter_group_connections={}").format( + num_connections, num_intra_group_connections, + num_inter_group_connections)) + self.dag_matrix = dag_matrix + + # output_shape + op_output_shapes = np.zeros( + [ + len(self.important_ops), + self.hparams.max_num_outputs * self.hparams.max_output_size + ], + dtype=np.float32) + + for idx, op in enumerate(self.important_ops): + for i, output_properties in enumerate(self.node_properties[op.name]): + if output_properties.shape.__str__() == "": + continue + if i > self.hparams.max_num_outputs: + break + shape = output_properties.shape + for j, dim in enumerate(shape.dim): + if dim.size > 0: + k = i * self.hparams.max_output_size + j + if k >= self.hparams.max_num_outputs * self.hparams.max_output_size: + break + op_output_shapes[idx, k] = dim.size + + # group_embedding + group_embedding = np.zeros( + [ + self.num_groups, len(self.type_dict) + + self.hparams.max_num_outputs * self.hparams.max_output_size + ], + dtype=np.float32) + for op_index, op in enumerate(self.important_ops): + group_index = grouping_actions[0][self.name_to_topo_order_index[op.name]] + type_name = str(op.op) + type_index = self.type_dict[type_name] + group_embedding[group_index, type_index] += 1 + group_embedding[group_index, :self.hparams.max_num_outputs * self.hparams. + max_output_size] += ( + op_output_shapes[op_index]) + grouping_adjacencies = np.concatenate( + [dag_matrix, np.transpose(dag_matrix)], axis=1) + group_embedding = np.concatenate( + [grouping_adjacencies, group_embedding], axis=1) + group_normalizer = np.amax(group_embedding, axis=1, keepdims=True) + group_embedding /= (group_normalizer + 1.0) + if verbose: + print("Finished Processing Input Graph") + return group_embedding + + def get_placements(self, *args, **kwargs): + num_children = self.hparams.num_children + with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): + actions_cache = variable_scope.get_local_variable( + "actions_cache", + initializer=init_ops.zeros_initializer, + dtype=dtypes.int32, + shape=[num_children, self.num_groups], + trainable=False) + + x = array_ops.tile(self.seq2seq_input_layer, [num_children, 1, 1]) + last_c, last_h, attn_mem = self.encode(x) + actions, log_probs = {}, {} + actions["sample"], log_probs["sample"] = ( + self.decode( + x, last_c, last_h, attn_mem, mode="sample")) + actions["target"], log_probs["target"] = ( + self.decode( + x, + last_c, + last_h, + attn_mem, + mode="target", + y=actions_cache)) + actions["greedy"], log_probs["greedy"] = ( + self.decode( + x, last_c, last_h, attn_mem, mode="greedy")) + actions["sample"] = control_flow_ops.cond( + self.global_step < self.hparams.stop_sampling, + lambda: state_ops.assign(actions_cache, actions["sample"]), + lambda: state_ops.assign(actions_cache, actions["target"])) + self.actions_cache = actions_cache + + return actions, log_probs + + def encode(self, x): + """Encoder using LSTM. + + Args: + x: tensor of size [num_children, num_groups, embedding_size] + + Returns: + last_c, last_h: tensors of size [num_children, hidden_size], the final + LSTM states + attn_mem: tensor of size [num_children, num_groups, hidden_size], the + attention + memory, i.e. concatenation of all hidden states, linearly transformed by + an attention matrix attn_w_1 + """ + if self.hparams.bi_lstm: + with variable_scope.variable_scope(self.hparams.name, reuse=True): + w_lstm_forward = variable_scope.get_variable("encoder_lstm_forward") + w_lstm_backward = variable_scope.get_variable("encoder_lstm_backward") + forget_bias = variable_scope.get_variable("encoder_forget_bias") + attn_w_1 = variable_scope.get_variable("attn_w_1") + else: + with variable_scope.variable_scope(self.hparams.name, reuse=True): + w_lstm = variable_scope.get_variable("encoder_lstm") + forget_bias = variable_scope.get_variable("encoder_forget_bias") + attn_w_1 = variable_scope.get_variable("attn_w_1") + + embedding_size = array_ops.shape(x)[2] + + signals = array_ops.split(x, self.num_groups, axis=1) + for i in range(len(signals)): + signals[i] = array_ops.reshape( + signals[i], [self.hparams.num_children, embedding_size]) + + if self.hparams.bi_lstm: + + def body(i, prev_c_forward, prev_h_forward, prev_c_backward, + prev_h_backward): + """while loop for LSTM.""" + signal_forward = signals[i] + next_c_forward, next_h_forward = lstm(signal_forward, prev_c_forward, + prev_h_forward, w_lstm_forward, + forget_bias) + + signal_backward = signals[self.num_groups - 1 - i] + next_c_backward, next_h_backward = lstm( + signal_backward, prev_c_backward, prev_h_backward, w_lstm_backward, + forget_bias) + + next_h = array_ops.concat([next_h_forward, next_h_backward], axis=1) + all_h.append(next_h) + + return (next_c_forward, next_h_forward, next_c_backward, + next_h_backward) + + c_forward = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size / 2], + dtype=dtypes.float32) + h_forward = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size / 2], + dtype=dtypes.float32) + + c_backward = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size / 2], + dtype=dtypes.float32) + h_backward = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size / 2], + dtype=dtypes.float32) + all_h = [] + + for i in range(0, self.num_groups): + c_forward, h_forward, c_backward, h_backward = body( + i, c_forward, h_forward, c_backward, h_backward) + + last_c = array_ops.concat([c_forward, c_backward], axis=1) + last_h = array_ops.concat([h_forward, h_backward], axis=1) + attn_mem = array_ops.stack(all_h) + + else: + + def body(i, prev_c, prev_h): + signal = signals[i] + next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias) + all_h.append(next_h) + return next_c, next_h + + c = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size], + dtype=dtypes.float32) + h = array_ops.zeros( + [self.hparams.num_children, self.hparams.hidden_size], + dtype=dtypes.float32) + all_h = [] + + for i in range(0, self.num_groups): + c, h = body(i, c, h) + + last_c = c + last_h = h + attn_mem = array_ops.stack(all_h) + + attn_mem = array_ops.transpose(attn_mem, [1, 0, 2]) + attn_mem = array_ops.reshape( + attn_mem, + [self.hparams.num_children * self.num_groups, self.hparams.hidden_size]) + attn_mem = math_ops.matmul(attn_mem, attn_w_1) + attn_mem = array_ops.reshape( + attn_mem, + [self.hparams.num_children, self.num_groups, self.hparams.hidden_size]) + + return last_c, last_h, attn_mem + + def decode(self, + x, + last_c, + last_h, + attn_mem, + mode="target", + y=None): + """Decoder using LSTM. + + Args: + x: tensor of size [num_children, num_groups, embedding_size]. + last_c: tensor of size [num_children, hidden_size], the final LSTM states + computed by self.encoder. + last_h: same as last_c. + attn_mem: tensor of size [num_children, num_groups, hidden_size]. + mode: "target" or "sample". + y: tensor of size [num_children, num_groups], the device placements. + + Returns: + actions: tensor of size [num_children, num_groups], the placements of + devices + """ + with variable_scope.variable_scope(self.hparams.name, reuse=True): + w_lstm = variable_scope.get_variable("decoder_lstm") + forget_bias = variable_scope.get_variable("decoder_forget_bias") + device_embeddings = variable_scope.get_variable("device_embeddings") + device_softmax = variable_scope.get_variable("device_softmax") + device_go_embedding = variable_scope.get_variable("device_go_embedding") + attn_w_2 = variable_scope.get_variable("attn_w_2") + attn_v = variable_scope.get_variable("attn_v") + + actions = tensor_array_ops.TensorArray( + dtypes.int32, + size=self.num_groups, + infer_shape=False, + clear_after_read=False) + + # pylint: disable=unused-argument + def condition(i, *args): + return math_ops.less(i, self.num_groups) + + # pylint: disable=missing-docstring + def body(i, prev_c, prev_h, actions, log_probs): + # pylint: disable=g-long-lambda + signal = control_flow_ops.cond( + math_ops.equal(i, 0), + lambda: array_ops.tile(device_go_embedding, + [self.hparams.num_children, 1]), + lambda: embedding_ops.embedding_lookup(device_embeddings, + actions.read(i - 1)) + ) + if self.hparams.keep_prob is not None: + signal = nn_ops.dropout(signal, self.hparams.keep_prob) + next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias) + query = math_ops.matmul(next_h, attn_w_2) + query = array_ops.reshape( + query, [self.hparams.num_children, 1, self.hparams.hidden_size]) + query = math_ops.tanh(query + attn_mem) + query = array_ops.reshape(query, [ + self.hparams.num_children * self.num_groups, self.hparams.hidden_size + ]) + query = math_ops.matmul(query, attn_v) + query = array_ops.reshape(query, + [self.hparams.num_children, self.num_groups]) + query = nn_ops.softmax(query) + query = array_ops.reshape(query, + [self.hparams.num_children, self.num_groups, 1]) + query = math_ops.reduce_sum(attn_mem * query, axis=1) + query = array_ops.concat([next_h, query], axis=1) + logits = math_ops.matmul(query, device_softmax) + logits /= self.hparams.temperature + if self.hparams.tanh_constant > 0: + logits = math_ops.tanh(logits) * self.hparams.tanh_constant + if self.hparams.logits_std_noise > 0: + num_in_logits = math_ops.cast( + array_ops.size(logits), dtype=dtypes.float32) + avg_norm = math_ops.divide( + linalg_ops.norm(logits), math_ops.sqrt(num_in_logits)) + logits_noise = random_ops.random_normal( + array_ops.shape(logits), + stddev=self.hparams.logits_std_noise * avg_norm) + logits = control_flow_ops.cond( + self.global_step > self.hparams.stop_noise_step, lambda: logits, + lambda: logits + logits_noise) + + if mode == "sample": + next_y = random_ops.multinomial(logits, 1, seed=self.hparams.seed) + elif mode == "greedy": + next_y = math_ops.argmax(logits, 1) + elif mode == "target": + next_y = array_ops.slice(y, [0, i], [-1, 1]) + else: + raise NotImplementedError + next_y = math_ops.to_int32(next_y) + next_y = array_ops.reshape(next_y, [self.hparams.num_children]) + actions = actions.write(i, next_y) + log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits( + logits=logits, labels=next_y) + return i + 1, next_c, next_h, actions, log_probs + + loop_vars = [ + constant_op.constant(0, dtype=dtypes.int32), last_c, last_h, actions, + array_ops.zeros([self.hparams.num_children], dtype=dtypes.float32) + ] + loop_outputs = control_flow_ops.while_loop(condition, body, loop_vars) + + last_c = loop_outputs[-4] + last_h = loop_outputs[-3] + actions = loop_outputs[-2].stack() + actions = array_ops.transpose(actions, [1, 0]) + log_probs = loop_outputs[-1] + return actions, log_probs + + def eval_placement(self, + sess, + child_id=0, + verbose=False): + grouping_actions, actions = sess.run([ + self.grouping_actions_cache, + self.actions_cache + ]) + grouping_actions = grouping_actions[child_id] + actions = actions[child_id] + if verbose: + global_step = sess.run(self.global_step) + if global_step % 100 == 0: + log_string = "op group assignments: " + for a in grouping_actions: + log_string += "{} ".format(a) + print(log_string[:-1]) + log_string = "group device assignments: " + for a in actions: + log_string += "{} ".format(a) + print(log_string[:-1]) + + for op in self.important_ops: + topo_order_index = self.name_to_topo_order_index[op.name] + group_index = grouping_actions[topo_order_index] + op.device = self.devices[actions[group_index]].name + try: + _, run_time, _ = self.cluster.MeasureCosts(self.item) + except errors.ResourceExhaustedError: + run_time = self.hparams.failing_signal + return run_time + + def update_reward(self, + sess, + run_time, + child_id=0, + verbose=False): + reward = self.compute_reward(run_time) + controller_ops = self.ops["controller"] + _, best_reward = sess.run( + [ + controller_ops["reward"]["update"][child_id], + controller_ops["best_reward"]["update"][child_id] + ], + feed_dict={ + controller_ops["reward"]["ph"][child_id]: reward, + }) + if verbose: + print(("run_time={:<.5f} reward={:<.5f} " + "best_reward={:<.5f}").format(run_time, reward, best_reward)) + + # Reward is a double, best_reward a float: allow for some slack in the + # comparison. + updated = abs(best_reward - reward) < 1e-6 + return updated + + def generate_grouping(self, sess): + controller_ops = self.ops["controller"] + grouping_actions = sess.run(controller_ops["grouping_y_preds"]["sample"]) + return grouping_actions + + def generate_placement(self, grouping, sess): + controller_ops = self.ops["controller"] + feed_seq2seq_input_dict = {} + feed_seq2seq_input_dict[self.seq2seq_input_layer] = np.expand_dims( + grouping, axis=0) + sess.run( + controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict) + + def process_reward(self, sess): + controller_ops = self.ops["controller"] + run_ops = [ + controller_ops["loss"], controller_ops["lr"], + controller_ops["grad_norm"], controller_ops["grad_norms"], + controller_ops["train_op"] + ] + sess.run(run_ops) + sess.run(controller_ops["baseline_update"]) + + def _get_train_ops(self, + loss, + tf_variables, + global_step, + grad_bound=1.25, + lr_init=1e-3, + lr_dec=0.9, + start_decay_step=10000, + decay_steps=100, + optimizer_type="adam"): + """Loss optimizer. + + Args: + loss: scalar tf tensor + tf_variables: list of training variables, typically + tf.trainable_variables() + global_step: global_step + grad_bound: max gradient norm + lr_init: initial learning rate + lr_dec: leaning rate decay coefficient + start_decay_step: start decaying learning rate after this many steps + decay_steps: apply decay rate factor at this step intervals + optimizer_type: optimizer type should be either adam or sgd + + Returns: + train_op: training op + learning_rate: scalar learning rate tensor + grad_norm: l2 norm of the gradient vector + all_grad_norms: l2 norm of each component + """ + lr_gstep = global_step - start_decay_step + + def f1(): + return constant_op.constant(lr_init) + + def f2(): + return learning_rate_decay.exponential_decay(lr_init, lr_gstep, + decay_steps, lr_dec, True) + + learning_rate = control_flow_ops.cond( + math_ops.less(global_step, start_decay_step), + f1, + f2, + name="learning_rate") + + if optimizer_type == "adam": + opt = adam.AdamOptimizer(learning_rate) + elif optimizer_type == "sgd": + opt = gradient_descent.GradientDescentOptimizer(learning_rate) + grads_and_vars = opt.compute_gradients(loss, tf_variables) + grad_norm = clip_ops.global_norm([g for g, v in grads_and_vars]) + all_grad_norms = {} + clipped_grads = [] + clipped_rate = math_ops.maximum(grad_norm / grad_bound, 1.0) + for g, v in grads_and_vars: + if g is not None: + if isinstance(g, tf_ops.IndexedSlices): + clipped = g.values / clipped_rate + norm_square = math_ops.reduce_sum(clipped * clipped) + clipped = tf_ops.IndexedSlices(clipped, g.indices) + else: + clipped = g / clipped_rate + norm_square = math_ops.reduce_sum(clipped * clipped) + all_grad_norms[v.name] = math_ops.sqrt(norm_square) + clipped_grads.append((clipped, v)) + + train_op = opt.apply_gradients(clipped_grads, global_step) + return train_op, learning_rate, grad_norm, all_grad_norms + + +def lstm(x, prev_c, prev_h, w_lstm, forget_bias): + """LSTM cell. + + Args: + x: tensors of size [num_children, hidden_size]. + prev_c: tensors of size [num_children, hidden_size]. + prev_h: same as prev_c. + w_lstm: . + forget_bias: . + + Returns: + next_c: + next_h: + """ + ifog = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w_lstm) + i, f, o, g = array_ops.split(ifog, 4, axis=1) + i = math_ops.sigmoid(i) + f = math_ops.sigmoid(f + forget_bias) + o = math_ops.sigmoid(o) + g = math_ops.tanh(g) + next_c = i * g + f * prev_c + next_h = o * math_ops.tanh(next_c) + return next_c, next_h diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i index d0fc1a04f220e0a053257e0206bb07b25f3767c6..9a84c60b04029a64ed35a01f045a6eec5e492504 100644 --- a/tensorflow/python/grappler/item.i +++ b/tensorflow/python/grappler/item.i @@ -96,10 +96,10 @@ static GItem TF_NewItem( return GItem(item.release()); } -static std::vector TF_IdentifyImportantOps(GItem item, bool sort_topologically, +static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically, TF_Status* status) { if (item.is_none()) { - return {}; + Py_RETURN_NONE; } std::vector main_ops = item->MainOpsFanin(); @@ -132,7 +132,13 @@ static std::vector TF_IdentifyImportantOps(GItem item, bool sort_topolog } } - return ops; + PyGILState_STATE gstate = PyGILState_Ensure(); + PyObject* result = PyList_New(ops.size()); + for (int i = 0; i < ops.size(); ++i) { + PyList_SetItem(result, i, PyString_FromString(ops[i].c_str())); + } + PyGILState_Release(gstate); + return result; } static PyObject* TF_GetOpProperties(GItem item) { @@ -305,7 +311,7 @@ static PyObject* TF_GetColocationGroups(GItem item) { static GItem TF_NewItem( const tensorflow::MetaGraphDef& meta_graph, bool ignore_colocation, bool ignore_user_placement, TF_Status* out_status); -static std::vector TF_IdentifyImportantOps(GItem item, bool sort_topologically, - TF_Status* status); +static PyObject* TF_IdentifyImportantOps(GItem item, bool sort_topologically, + TF_Status* status); static PyObject* TF_GetOpProperties(GItem item); static PyObject* TF_GetColocationGroups(GItem item); diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py index cd70e2fdecc74f9d99240ac566f3c28e900a06c2..7c3efd6249cbdaa2675632f7fc8e25fb88658a24 100644 --- a/tensorflow/python/grappler/item_test.py +++ b/tensorflow/python/grappler/item_test.py @@ -56,7 +56,7 @@ class ItemTest(test.TestCase): mg = meta_graph.create_meta_graph_def(graph=g) grappler_item = item.Item(mg) op_list = grappler_item.IdentifyImportantOps() - self.assertItemsEqual([b'Const', b'Const_1', b'add'], op_list) + self.assertItemsEqual(['Const', 'Const_1', 'add'], op_list) def testOpProperties(self): with ops.Graph().as_default() as g: diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py index 30dcdf31aadd9effc96a7df751c127570f1fb8d8..0f5150174049250e86bbac0a49eb998339058326 100644 --- a/tensorflow/python/grappler/layout_optimizer_test.py +++ b/tensorflow/python/grappler/layout_optimizer_test.py @@ -256,7 +256,7 @@ class LayoutOptimizerTest(test.TestCase): x = random_ops.truncated_normal([1, 784], seed=0) output = _two_layer_model(x) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -293,7 +293,7 @@ class LayoutOptimizerTest(test.TestCase): add = bn0[0] + bn1[0] output = array_ops.identity(add) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={dim: 3}) with session.Session(config=_get_config()) as sess: @@ -325,7 +325,7 @@ class LayoutOptimizerTest(test.TestCase): value=conv, size_splits=sizes, axis=dim, num_split=3) output = math_ops.reduce_sum(split[0]) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={dim: 3}) with session.Session(config=_get_config()) as sess: @@ -359,7 +359,7 @@ class LayoutOptimizerTest(test.TestCase): pad = array_ops.pad(conv, paddings) output = array_ops.identity(pad) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -390,7 +390,7 @@ class LayoutOptimizerTest(test.TestCase): reduce_sum = math_ops.reduce_sum(conv) output = array_ops.identity(reduce_sum) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -419,7 +419,7 @@ class LayoutOptimizerTest(test.TestCase): cast = math_ops.cast(conv, dtype='bool') output = array_ops.identity(cast) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -450,7 +450,67 @@ class LayoutOptimizerTest(test.TestCase): squeeze = array_ops.squeeze(reduce_sum) output = array_ops.identity(squeeze) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Three transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 1 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + def testSqueezeAlongHW(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + reduce_sum = math_ops.reduce_sum(conv, axis=[1, 2], keep_dims=True) + squeeze = array_ops.squeeze(reduce_sum, axis=[1, 2]) + output = array_ops.identity(squeeze) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Three transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 1 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + def testSqueezeAlongNHW(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + reduce_sum = math_ops.reduce_sum(conv, axis=[0, 1, 2], keep_dims=True) + squeeze = array_ops.squeeze(reduce_sum, axis=[0, 1, 2]) + output = array_ops.identity(squeeze) + + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -479,7 +539,7 @@ class LayoutOptimizerTest(test.TestCase): reduce_sum = math_ops.reduce_sum(conv, axis=[1, 2, 3]) output = array_ops.identity(reduce_sum) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -508,7 +568,7 @@ class LayoutOptimizerTest(test.TestCase): reduce_sum = math_ops.reduce_sum(conv, axis=[0, 1, 2]) output = array_ops.identity(reduce_sum) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -537,7 +597,7 @@ class LayoutOptimizerTest(test.TestCase): reduce_sum = math_ops.reduce_sum(conv, axis=[3]) output = array_ops.identity(reduce_sum) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -558,6 +618,94 @@ class LayoutOptimizerTest(test.TestCase): self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testReduceSumAlongCKeepDims(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + reduce_sum = math_ops.reduce_sum(conv, axis=[3], keep_dims=True) + output = array_ops.identity(reduce_sum) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Four transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self._assert_trans_nchw_to_nhwc('Sum-0-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + def testReduceSumAlongHKeepDims(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + reduce_sum = math_ops.reduce_sum(conv, axis=[2], keep_dims=True) + output = array_ops.identity(reduce_sum) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Four transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + + def testReduceSumAlongWCKeepDims(self): + if test.is_gpu_available(cuda_only=True): + random_seed.set_random_seed(0) + x = random_ops.truncated_normal([1, 784], seed=0) + conv = _two_layer_model(x) + reduce_sum = math_ops.reduce_sum(conv, axis=[2, 3], keep_dims=True) + output = array_ops.identity(reduce_sum) + + with session.Session(config=_get_config(False)) as sess: + output_val_ref = sess.run(output) + + with session.Session(config=_get_config()) as sess: + metadata = config_pb2.RunMetadata() + output_val = sess.run(output, run_metadata=metadata) + + nodes = [] + num_transposes = 0 + for node in metadata.cost_graph.node: + if _is_transpose(node.name): + num_transposes += 1 + nodes.append(node.name) + + # Four transposes were initially added in the Expand phase of + # LayoutOptimizer; two of them are cancelled out in the Collapse phase. + expected_num_transposes = 2 + self.assertEqual(expected_num_transposes, num_transposes) + self._assert_trans_nhwc_to_nchw('Conv2D-0', nodes) + self.assertAllClose(output_val_ref, output_val, atol=1e-3) + def testConcatWithControlDependency(self): if test.is_gpu_available(cuda_only=True): random_seed.set_random_seed(0) @@ -570,7 +718,7 @@ class LayoutOptimizerTest(test.TestCase): concat = array_ops.concat([conv, conv], axis) output = array_ops.identity(concat) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -604,7 +752,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(fill) x_val = [3.4] * 784 - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={x: x_val}) with session.Session(config=_get_config()) as sess: @@ -646,7 +794,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(tile) multiple_val = [2, 3, 4, 1] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={multiple: multiple_val}) with session.Session(config=_get_config()) as sess: @@ -681,7 +829,7 @@ class LayoutOptimizerTest(test.TestCase): reverse = array_ops.reverse(conv, dims) output = array_ops.identity(reverse) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -714,7 +862,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(reverse) dims_val = [2, 3] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={dims: dims_val}) with session.Session(config=_get_config()) as sess: @@ -751,7 +899,7 @@ class LayoutOptimizerTest(test.TestCase): select = gen_math_ops._select(condition, conv, add) output = array_ops.identity(select) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -782,7 +930,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(select) condition_val = np.zeros((1, 7, 7, 64)) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={condition: condition_val}) with session.Session(config=_get_config()) as sess: @@ -812,7 +960,7 @@ class LayoutOptimizerTest(test.TestCase): select = gen_math_ops._select(condition, conv, add) output = array_ops.identity(select) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -842,7 +990,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(pad) paddings_val = [[1, 2], [3, 4], [5, 6], [7, 8]] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={paddings: paddings_val}) with session.Session(config=_get_config()) as sess: @@ -879,7 +1027,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(max_pool) strides_val = [1, 3, 2, 1] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={strides: strides_val}) with session.Session(config=_get_config()) as sess: @@ -916,7 +1064,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(max_pool_grad) strides_val = [1, 3, 2, 1] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={strides: strides_val}) with session.Session(config=_get_config()) as sess: @@ -951,7 +1099,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(s) size_val = [1, 2, 3, 4] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={size: size_val}) with session.Session(config=_get_config()) as sess: @@ -987,7 +1135,7 @@ class LayoutOptimizerTest(test.TestCase): output = array_ops.identity(s) end_val = [1, 2, 3, 4] - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={end: end_val}) with session.Session(config=_get_config()) as sess: @@ -1025,7 +1173,7 @@ class LayoutOptimizerTest(test.TestCase): s = conv[:, :, 1:-1, :] output = array_ops.identity(s) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -1060,7 +1208,7 @@ class LayoutOptimizerTest(test.TestCase): s = conv[:, :, :, 1:-1] output = array_ops.identity(s) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -1099,7 +1247,7 @@ class LayoutOptimizerTest(test.TestCase): [1, 2, 3, 1], s) output = array_ops.identity(s_grad) - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={end: end_val}) with session.Session(config=_get_config()) as sess: @@ -1135,7 +1283,7 @@ class LayoutOptimizerTest(test.TestCase): output = math_ops.add(shapen[0], shapen[1]) x_val = [1.7] * 784 - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={x: x_val}) with session.Session(config=_get_config()) as sess: @@ -1169,7 +1317,7 @@ class LayoutOptimizerTest(test.TestCase): output = math_ops.add_n([conv_reshape, ones]) x_val = [1.7] * 784 - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output, feed_dict={x: x_val}) with session.Session(config=_get_config()) as sess: @@ -1193,7 +1341,7 @@ class LayoutOptimizerTest(test.TestCase): if test.is_gpu_available(cuda_only=True): output = _loop() - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -1220,7 +1368,7 @@ class LayoutOptimizerTest(test.TestCase): if test.is_gpu_available(cuda_only=True): output = _loop_with_branch() - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -1244,7 +1392,7 @@ class LayoutOptimizerTest(test.TestCase): if test.is_gpu_available(cuda_only=True): output = _loop_with_vec_and_4d() - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: @@ -1268,7 +1416,7 @@ class LayoutOptimizerTest(test.TestCase): if test.is_gpu_available(cuda_only=True): output = _model_with_second_port() - with session.Session() as sess: + with session.Session(config=_get_config(False)) as sess: output_val_ref = sess.run(output) with session.Session(config=_get_config()) as sess: diff --git a/tensorflow/python/grappler/tf_optimizer.i b/tensorflow/python/grappler/tf_optimizer.i index 1b657983a4690dd0ddb7f569ce514b08cb10400a..de9326ccfc1653c2afd0833dcdca2cc4bfdabed5 100644 --- a/tensorflow/python/grappler/tf_optimizer.i +++ b/tensorflow/python/grappler/tf_optimizer.i @@ -100,6 +100,7 @@ PyObject* TF_OptimizeGraph( tensorflow::grappler::ItemConfig item_config; item_config.inline_functions = false; item_config.apply_optimizations = false; + item_config.ignore_user_placement = false; std::unique_ptr grappler_item = tensorflow::grappler::GrapplerItemFromMetaGraphDef(graph_id, metagraph, item_config); diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py index 55dcbe2071f74204e0bbdd141560f33cefdf174d..3ee4d7807ea5677a742514eb56267b94c6b92bba 100644 --- a/tensorflow/python/grappler/tf_optimizer_test.py +++ b/tensorflow/python/grappler/tf_optimizer_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import meta_graph from tensorflow.python.framework import ops from tensorflow.python.grappler import tf_optimizer from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -48,6 +49,31 @@ class PyWrapOptimizeGraphTest(test.TestCase): self.assertEqual(len(graph.node), 1) self.assertItemsEqual([node.name for node in graph.node], ['d']) + def testKeepNodes(self): + g = ops.Graph() + with g.as_default(): + a1 = variables.Variable( + 1.0) # Must be preserved since it's in the collection 'variables'. + a2 = constant_op.constant(0, shape=[50, 50], name='keep') + ops.add_to_collection('a2', a2) # Explicitly add to collection. + b = constant_op.constant(1, shape=[100, 10]) + c = constant_op.constant(0, shape=[10, 30]) + d = math_ops.matmul(b, c) + ops.add_to_collection('train_op', d) # d is the fetch node. + + # Optimize the graph. + mg = meta_graph.create_meta_graph_def(graph=g) + rewriter_config = rewriter_config_pb2.RewriterConfig() + optimized_graph = tf_optimizer.OptimizeGraph(rewriter_config, mg) + + # Check that the nodes referenced in various collections have been preserved + self.assertEqual(len(optimized_graph.node), 5) + self.assertEqual(d.op.name, optimized_graph.node[0].name) + self.assertEqual(a1.op.name, optimized_graph.node[1].name) + self.assertEqual('Variable/initial_value', optimized_graph.node[2].name) + self.assertEqual(a2.op.name, optimized_graph.node[3].name) + self.assertEqual('Variable/Assign', optimized_graph.node[4].name) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD index fdac22bb53cc7e78d854d4b5ff756a190c9c62b6..a98d08f92892cd5f923833a2059ce7e89ebba1aa 100755 --- a/tensorflow/python/keras/BUILD +++ b/tensorflow/python/keras/BUILD @@ -3,6 +3,8 @@ licenses(["notice"]) # Apache 2.0 +exports_files(["LICENSE"]) + package(default_visibility = ["//visibility:public"]) load("//tensorflow:tensorflow.bzl", "py_test") @@ -37,7 +39,11 @@ py_library( "_impl/keras/datasets/mnist.py", "_impl/keras/datasets/reuters.py", "_impl/keras/engine/__init__.py", - "_impl/keras/engine/topology.py", + "_impl/keras/engine/base_layer.py", + "_impl/keras/engine/input_layer.py", + "_impl/keras/engine/network.py", + "_impl/keras/engine/saving.py", + "_impl/keras/engine/sequential.py", "_impl/keras/engine/training.py", "_impl/keras/engine/training_eager.py", "_impl/keras/estimator.py", @@ -254,6 +260,11 @@ py_test( size = "small", srcs = ["_impl/keras/metrics_test.py"], srcs_version = "PY2AND3", + tags = [ + "manual", + "no_oss", + "notap", + ], deps = [ ":keras", "//tensorflow/python:client_testlib", @@ -393,7 +404,7 @@ py_test( py_test( name = "convolutional_test", - size = "medium", + size = "large", srcs = ["_impl/keras/layers/convolutional_test.py"], srcs_version = "PY2AND3", tags = [ @@ -734,6 +745,19 @@ py_test( ], ) +py_test( + name = "model_subclassing_test", + size = "medium", + srcs = ["_impl/keras/model_subclassing_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + py_test( name = "topology_test", size = "small", @@ -741,9 +765,31 @@ py_test( srcs_version = "PY2AND3", deps = [ ":keras", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", + "//third_party/py/numpy", + ], +) + +py_test( + name = "saving_test", + size = "small", + srcs = ["_impl/keras/engine/saving_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":keras", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +py_test( + name = "sequential_test", + size = "small", + srcs = ["_impl/keras/engine/sequential_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":keras", + "//tensorflow/python:client_testlib", "//third_party/py/numpy", ], ) @@ -764,7 +810,7 @@ py_test( py_test( name = "estimator_test", - size = "medium", + size = "large", srcs = ["_impl/keras/estimator_test.py"], srcs_version = "PY2AND3", tags = ["notsan"], diff --git a/tensorflow/python/keras/_impl/keras/__init__.py b/tensorflow/python/keras/_impl/keras/__init__.py index 73113539329c5493141db243b85254062f7b8f88..b63907b2e60acfc80ee411b9193b2829f0224c3e 100644 --- a/tensorflow/python/keras/_impl/keras/__init__.py +++ b/tensorflow/python/keras/_impl/keras/__init__.py @@ -40,4 +40,4 @@ from tensorflow.python.keras._impl.keras.layers import Input from tensorflow.python.keras._impl.keras.models import Model from tensorflow.python.keras._impl.keras.models import Sequential -__version__ = '2.1.3-tf' +__version__ = '2.1.4-tf' diff --git a/tensorflow/python/keras/_impl/keras/applications/densenet.py b/tensorflow/python/keras/_impl/keras/applications/densenet.py index 6521f8410435fd13393b9991d3ee9a6342a912d0..ca83e8691237216e799f2ca738dcb6822506e2cb 100644 --- a/tensorflow/python/keras/_impl/keras/applications/densenet.py +++ b/tensorflow/python/keras/_impl/keras/applications/densenet.py @@ -31,7 +31,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D from tensorflow.python.keras._impl.keras.layers import BatchNormalization diff --git a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py index d9cb726137409f899bc75e3c19f6bffeb3ca4e34..c26a28ed4087e30968585ec8ac0b64b51513bcae 100644 --- a/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py +++ b/tensorflow/python/keras/_impl/keras/applications/imagenet_utils.py @@ -234,7 +234,8 @@ def decode_predictions(preds, top=5): CLASS_INDEX_PATH, cache_subdir='models', file_hash='c2c37ea517e94d9795004a39431a14cb') - CLASS_INDEX = json.load(open(fpath)) + with open(fpath) as f: + CLASS_INDEX = json.load(f) results = [] for pred in preds: top_indices = pred.argsort()[-top:][::-1] diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py index bf3901fc54419c2b401bf9c4d6311b39a18f1aba..17e407dd58460e6d6802a3e137a96faf38a6f576 100644 --- a/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py +++ b/tensorflow/python/keras/_impl/keras/applications/inception_resnet_v2.py @@ -31,7 +31,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D from tensorflow.python.keras._impl.keras.layers import BatchNormalization diff --git a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py index e268e97bc663773a218f01b958b08f8e43c74ee2..2897c6058eb445ceacc34084b53dc89f556e3e9c 100644 --- a/tensorflow/python/keras/_impl/keras/applications/inception_v3.py +++ b/tensorflow/python/keras/_impl/keras/applications/inception_v3.py @@ -37,7 +37,7 @@ from tensorflow.python.keras._impl.keras import layers from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D from tensorflow.python.keras._impl.keras.layers import BatchNormalization diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py index 027ae26113a42782fbbee27d993b85cb3aebbf23..ad96b53a4528d99a014a0214b52a78d6a60076f8 100644 --- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py +++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py @@ -79,8 +79,8 @@ from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import BatchNormalization from tensorflow.python.keras._impl.keras.layers import Conv2D @@ -561,7 +561,7 @@ def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)): and width and height should be no smaller than 32. E.g. `(224, 224, 3)` would be one valid value. filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). alpha: controls the width of the network. - If `alpha` < 1.0, proportionally decreases the number of filters in each layer. @@ -627,7 +627,7 @@ def _depthwise_conv_block(inputs, (with `channels_last` data format) or (channels, rows, cols) (with `channels_first` data format). pointwise_conv_filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the pointwise convolution). + (i.e. the number of output filters in the pointwise convolution). alpha: controls the width of the network. - If `alpha` < 1.0, proportionally decreases the number of filters in each layer. diff --git a/tensorflow/python/keras/_impl/keras/applications/nasnet.py b/tensorflow/python/keras/_impl/keras/applications/nasnet.py index 08dae57f006c64021cbca26404770cd89b1ce176..dd33230a7eb9272f8fc60daee63e1f92574cf5e3 100644 --- a/tensorflow/python/keras/_impl/keras/applications/nasnet.py +++ b/tensorflow/python/keras/_impl/keras/applications/nasnet.py @@ -49,7 +49,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.applications.inception_v3 import preprocess_input -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import add from tensorflow.python.keras._impl.keras.layers import AveragePooling2D diff --git a/tensorflow/python/keras/_impl/keras/applications/resnet50.py b/tensorflow/python/keras/_impl/keras/applications/resnet50.py index a47dd657bb9ea0627d82831b7ee5d0b33788b5b7..46c0e635578c7f4707b027247943d75b16d703ad 100644 --- a/tensorflow/python/keras/_impl/keras/applications/resnet50.py +++ b/tensorflow/python/keras/_impl/keras/applications/resnet50.py @@ -34,7 +34,7 @@ from tensorflow.python.keras._impl.keras import layers from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import AveragePooling2D from tensorflow.python.keras._impl.keras.layers import BatchNormalization diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg16.py b/tensorflow/python/keras/_impl/keras/applications/vgg16.py index 9da74253abc2124844ab89b7727ddda4f754d8e2..cefb25063e30505c9c34b49fd2df6eb7210d7ca8 100644 --- a/tensorflow/python/keras/_impl/keras/applications/vgg16.py +++ b/tensorflow/python/keras/_impl/keras/applications/vgg16.py @@ -32,7 +32,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Conv2D from tensorflow.python.keras._impl.keras.layers import Dense from tensorflow.python.keras._impl.keras.layers import Flatten diff --git a/tensorflow/python/keras/_impl/keras/applications/vgg19.py b/tensorflow/python/keras/_impl/keras/applications/vgg19.py index 961c1f991893dbc0df858e9f72b61202c9fee500..dadaf4fdf0cc5922752c6867720c5d8cdbcab19a 100644 --- a/tensorflow/python/keras/_impl/keras/applications/vgg19.py +++ b/tensorflow/python/keras/_impl/keras/applications/vgg19.py @@ -32,7 +32,7 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions from tensorflow.python.keras._impl.keras.applications.imagenet_utils import preprocess_input -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Conv2D from tensorflow.python.keras._impl.keras.layers import Dense from tensorflow.python.keras._impl.keras.layers import Flatten diff --git a/tensorflow/python/keras/_impl/keras/applications/xception.py b/tensorflow/python/keras/_impl/keras/applications/xception.py index 7e7ca5a18a31622ac79d61ab01ce65341a4a46c5..971063a16d1f5ba0e25189f1ef2f6c24eb5f8d61 100644 --- a/tensorflow/python/keras/_impl/keras/applications/xception.py +++ b/tensorflow/python/keras/_impl/keras/applications/xception.py @@ -44,7 +44,7 @@ from tensorflow.python.keras._impl.keras import layers from tensorflow.python.keras._impl.keras.applications import imagenet_utils from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs from tensorflow.python.keras._impl.keras.layers import Activation from tensorflow.python.keras._impl.keras.layers import BatchNormalization from tensorflow.python.keras._impl.keras.layers import Conv2D diff --git a/tensorflow/python/keras/_impl/keras/backend.py b/tensorflow/python/keras/_impl/keras/backend.py index 6988e0332f057cfa635efab4611ecd33368908a0..2b75666b9e61baea635a312c005fdbd955f6cab6 100644 --- a/tensorflow/python/keras/_impl/keras/backend.py +++ b/tensorflow/python/keras/_impl/keras/backend.py @@ -258,7 +258,7 @@ def set_image_data_format(data_format): """ global _IMAGE_DATA_FORMAT if data_format not in {'channels_last', 'channels_first'}: - raise ValueError('Unknown data_format:', data_format) + raise ValueError('Unknown data_format: ' + str(data_format)) _IMAGE_DATA_FORMAT = str(data_format) @@ -342,13 +342,11 @@ def learning_phase(): Returns: Learning phase (scalar integer tensor or Python integer). - - Raises: - ValueError: If called when Eager execution is enabled. """ if context.in_eager_mode(): if 'eager' not in _GRAPH_LEARNING_PHASES: - raise ValueError('No learning phase set in Eager mode.') + # Fallback to inference mode as default. + return 0 return _GRAPH_LEARNING_PHASES['eager'] graph = ops.get_default_graph() @@ -488,7 +486,7 @@ def _get_available_gpus(): def _has_nchw_support(): """Check whether the current scope supports NCHW ops. - Tensorflow does not support NCHW on CPU. Therefore we check if we are not + TensorFlow does not support NCHW on CPU. Therefore we check if we are not explicitly put on CPU, and have GPUs available. In this case there will be soft-placing on the GPU device. @@ -2232,7 +2230,7 @@ def resize_images(x, height_factor, width_factor, data_format): if original_shape[2] is not None else None, None)) return x else: - raise ValueError('Invalid data_format:', data_format) + raise ValueError('Invalid data_format: ' + str(data_format)) @tf_export('keras.backend.resize_volumes') @@ -2264,7 +2262,7 @@ def resize_volumes(x, depth_factor, height_factor, width_factor, data_format): output = repeat_elements(output, width_factor, axis=3) return output else: - raise ValueError('Invalid data_format:', data_format) + raise ValueError('Invalid data_format: ' + str(data_format)) @tf_export('keras.backend.repeat_elements') @@ -2346,7 +2344,7 @@ def arange(start, stop=None, step=1, dtype='int32'): The function arguments use the same convention as Theano's arange: if only one argument is provided, - it is in fact the "stop" argument. + it is in fact the "stop" argument and "start" is 0. The default type of the returned tensor is `'int32'` to match TensorFlow's default. @@ -2361,7 +2359,7 @@ def arange(start, stop=None, step=1, dtype='int32'): An integer tensor. """ - # Match the behavior of numpy and Theano by returning an empty seqence. + # Match the behavior of numpy and Theano by returning an empty sequence. if stop is None and start < 0: start = 0 result = math_ops.range(start, limit=stop, delta=step, name='arange') @@ -2482,7 +2480,7 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None): if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) if data_format == 'channels_first': pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])] @@ -2523,7 +2521,7 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None): if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) if data_format == 'channels_first': pattern = [[0, 0], [0, 0], [padding[0][0], padding[0][1]], @@ -2598,6 +2596,8 @@ def get_value(x): Returns: A Numpy array. """ + if context.in_eager_mode(): + return x.numpy() return x.eval(session=get_session()) @@ -2611,6 +2611,8 @@ def batch_get_value(tensors): Returns: A list of Numpy arrays. """ + if context.in_eager_mode(): + return [x.numpy() for x in tensors] if tensors: return get_session().run(tensors) else: @@ -2627,16 +2629,19 @@ def set_value(x, value): (of the same shape). """ value = np.asarray(value, dtype=dtype(x)) - tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0]) - if hasattr(x, '_assign_placeholder'): - assign_placeholder = x._assign_placeholder - assign_op = x._assign_op + if context.in_eager_mode(): + x.assign(value) else: - assign_placeholder = array_ops.placeholder(tf_dtype, shape=value.shape) - assign_op = x.assign(assign_placeholder) - x._assign_placeholder = assign_placeholder - x._assign_op = assign_op - get_session().run(assign_op, feed_dict={assign_placeholder: value}) + tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0]) + if hasattr(x, '_assign_placeholder'): + assign_placeholder = x._assign_placeholder + assign_op = x._assign_op + else: + assign_placeholder = array_ops.placeholder(tf_dtype, shape=value.shape) + assign_op = x.assign(assign_placeholder) + x._assign_placeholder = assign_placeholder + x._assign_op = assign_op + get_session().run(assign_op, feed_dict={assign_placeholder: value}) @tf_export('keras.backend.batch_set_value') @@ -2647,23 +2652,28 @@ def batch_set_value(tuples): tuples: a list of tuples `(tensor, value)`. `value` should be a Numpy array. """ - if tuples: - assign_ops = [] - feed_dict = {} + if context.in_eager_mode(): for x, value in tuples: - value = np.asarray(value, dtype=dtype(x)) - tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0]) - if hasattr(x, '_assign_placeholder'): - assign_placeholder = x._assign_placeholder - assign_op = x._assign_op - else: - assign_placeholder = array_ops.placeholder(tf_dtype, shape=value.shape) - assign_op = x.assign(assign_placeholder) - x._assign_placeholder = assign_placeholder - x._assign_op = assign_op - assign_ops.append(assign_op) - feed_dict[assign_placeholder] = value - get_session().run(assign_ops, feed_dict=feed_dict) + x.assign(np.asarray(value, dtype=dtype(x))) + else: + if tuples: + assign_ops = [] + feed_dict = {} + for x, value in tuples: + value = np.asarray(value, dtype=dtype(x)) + tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0]) + if hasattr(x, '_assign_placeholder'): + assign_placeholder = x._assign_placeholder + assign_op = x._assign_op + else: + assign_placeholder = array_ops.placeholder(tf_dtype, + shape=value.shape) + assign_op = x.assign(assign_placeholder) + x._assign_placeholder = assign_placeholder + x._assign_op = assign_op + assign_ops.append(assign_op) + feed_dict[assign_placeholder] = value + get_session().run(assign_ops, feed_dict=feed_dict) @tf_export('keras.backend.print_tensor') @@ -2739,7 +2749,7 @@ class Function(object): self.updates_op = control_flow_ops.group(*updates_ops) self.name = name # additional tensor substitutions - self.feed_dict = session_kwargs.pop('feed_dict', {}) + self.feed_dict = session_kwargs.pop('feed_dict', None) # additional operations self.fetches = session_kwargs.pop('fetches', []) if not isinstance(self.fetches, list): @@ -2749,8 +2759,15 @@ class Function(object): def __call__(self, inputs): if not isinstance(inputs, (list, tuple)): raise TypeError('`inputs` should be a list or tuple.') - feed_dict = self.feed_dict.copy() + + if self.feed_dict: + feed_dict = self.feed_dict.copy() + else: + feed_dict = {} + for tensor, value in zip(self.inputs, inputs): + if value is None: + continue if is_sparse(tensor): sparse_coo = value.tocoo() indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), @@ -2784,7 +2801,7 @@ def function(inputs, outputs, updates=None, **kwargs): for key in kwargs: if (key not in tf_inspect.getargspec(session_module.Session.run)[0] and key not in tf_inspect.getargspec(Function.__init__)[0]): - msg = ('Invalid argument "%s" passed to K.function with Tensorflow ' + msg = ('Invalid argument "%s" passed to K.function with TensorFlow ' 'backend') % key raise ValueError(msg) return Function(inputs, outputs, updates=updates, **kwargs) @@ -2904,7 +2921,7 @@ def rnn(step_function, if unroll: if not inputs.get_shape()[0]: - raise ValueError('Unrolling requires a ' 'fixed number of timesteps.') + raise ValueError('Unrolling requires a fixed number of timesteps.') states = initial_states successive_states = [] successive_outputs = [] @@ -3077,7 +3094,8 @@ def rnn(step_function, outputs_shape[1] = inputs_shape[1] outputs.set_shape(outputs_shape) - last_output._uses_learning_phase = uses_learning_phase + if not context.in_eager_mode(): + last_output._uses_learning_phase = uses_learning_phase return last_output, outputs, new_states @@ -3548,7 +3566,7 @@ def _preprocess_conv3d_input(x, data_format): def _preprocess_padding(padding): - """Convert keras' padding to tensorflow's padding. + """Convert keras' padding to TensorFlow's padding. Arguments: padding: string, one of 'same' , 'valid' @@ -3564,7 +3582,7 @@ def _preprocess_padding(padding): elif padding == 'valid': padding = 'VALID' else: - raise ValueError('Invalid padding:', padding) + raise ValueError('Invalid padding: ' + str(padding)) return padding @@ -3595,7 +3613,7 @@ def conv1d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) kernel_shape = kernel.get_shape().as_list() if padding == 'causal': @@ -3647,7 +3665,7 @@ def conv2d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv2d_input(x, data_format) padding = _preprocess_padding(padding) @@ -3694,7 +3712,7 @@ def conv2d_transpose(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) if isinstance(output_shape, (tuple, list)): output_shape = array_ops.stack(output_shape) @@ -3753,16 +3771,18 @@ def separable_conv1d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv1d_input(x, data_format) padding = _preprocess_padding(padding) + if not isinstance(strides, tuple): + strides = tuple(strides) if tf_data_format == 'NHWC': spatial_start_dim = 1 - strides = (1, 1) + strides + (1,) + strides = (1,) + strides * 2 + (1,) else: spatial_start_dim = 2 - strides = (1, 1, 1) + strides + strides = (1, 1) + strides * 2 x = array_ops.expand_dims(x, spatial_start_dim) depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0) pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0) @@ -3815,10 +3835,12 @@ def separable_conv2d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv2d_input(x, data_format) padding = _preprocess_padding(padding) + if not isinstance(strides, tuple): + strides = tuple(strides) if tf_data_format == 'NHWC': strides = (1,) + strides + (1,) else: @@ -3864,7 +3886,7 @@ def depthwise_conv2d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv2d_input(x, data_format) padding = _preprocess_padding(padding) @@ -3914,7 +3936,7 @@ def conv3d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv3d_input(x, data_format) padding = _preprocess_padding(padding) @@ -3960,7 +3982,7 @@ def conv3d_transpose(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) if isinstance(output_shape, (tuple, list)): output_shape = array_ops.stack(output_shape) @@ -4019,7 +4041,7 @@ def pool2d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv2d_input(x, data_format) padding = _preprocess_padding(padding) @@ -4037,7 +4059,7 @@ def pool2d(x, x = nn.avg_pool( x, pool_size, strides, padding=padding, data_format=tf_data_format) else: - raise ValueError('Invalid pooling mode:', pool_mode) + raise ValueError('Invalid pooling mode: ' + str(pool_mode)) if data_format == 'channels_first' and tf_data_format == 'NHWC': x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW @@ -4072,7 +4094,7 @@ def pool3d(x, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) x, tf_data_format = _preprocess_conv3d_input(x, data_format) padding = _preprocess_padding(padding) @@ -4090,7 +4112,7 @@ def pool3d(x, x = nn.avg_pool3d( x, pool_size, strides, padding=padding, data_format=tf_data_format) else: - raise ValueError('Invalid pooling mode:', pool_mode) + raise ValueError('Invalid pooling mode: ' + str(pool_mode)) if data_format == 'channels_first' and tf_data_format == 'NDHWC': x = array_ops.transpose(x, (0, 4, 1, 2, 3)) @@ -4121,7 +4143,7 @@ def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None): if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) stride = strides[0] kernel_shape = int_shape(kernel) @@ -4177,7 +4199,7 @@ def local_conv2d(inputs, if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) stride_row, stride_col = strides output_row, output_col = output_shape @@ -4230,7 +4252,7 @@ def bias_add(x, bias, data_format=None): if data_format is None: data_format = image_data_format() if data_format not in {'channels_first', 'channels_last'}: - raise ValueError('Unknown data_format ' + str(data_format)) + raise ValueError('Unknown data_format: ' + str(data_format)) bias_shape = int_shape(bias) if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1: raise ValueError( diff --git a/tensorflow/python/keras/_impl/keras/callbacks.py b/tensorflow/python/keras/_impl/keras/callbacks.py index b29bc39232546ada8f73cd351ac2b9c3eccfa6da..deb1e8867dba3d52816ebda02bd9a3bf2ec7bc09 100644 --- a/tensorflow/python/keras/_impl/keras/callbacks.py +++ b/tensorflow/python/keras/_impl/keras/callbacks.py @@ -164,7 +164,7 @@ class CallbackList(object): class Callback(object): """Abstract base class used to build new callbacks. - # Properties + Attributes: params: dict. Training parameters (eg. verbosity, batch size, number of epochs...). model: instance of `keras.models.Model`. @@ -222,8 +222,18 @@ class BaseLogger(Callback): """Callback that accumulates epoch averages of metrics. This callback is automatically applied to every Keras model. + + Arguments: + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over an epoch. + Metrics in this list will be logged as-is in `on_epoch_end`. + All others will be averaged in `on_epoch_end`. """ + def __init__(self, stateful_metrics=None): + super(BaseLogger, self).__init__() + self.stateful_metrics = set(stateful_metrics or []) + def on_epoch_begin(self, epoch, logs=None): self.seen = 0 self.totals = {} @@ -234,17 +244,23 @@ class BaseLogger(Callback): self.seen += batch_size for k, v in logs.items(): - if k in self.totals: - self.totals[k] += v * batch_size + if k in self.stateful_metrics: + self.totals[k] = v else: - self.totals[k] = v * batch_size + if k in self.totals: + self.totals[k] += v * batch_size + else: + self.totals[k] = v * batch_size def on_epoch_end(self, epoch, logs=None): if logs is not None: for k in self.params['metrics']: if k in self.totals: # Make value available to next callbacks. - logs[k] = self.totals[k] / self.seen + if k in self.stateful_metrics: + logs[k] = self.totals[k] + else: + logs[k] = self.totals[k] / self.seen @tf_export('keras.callbacks.TerminateOnNaN') @@ -272,12 +288,16 @@ class ProgbarLogger(Callback): count_mode: One of "steps" or "samples". Whether the progress bar should count samples seen or steps (batches) seen. + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over an epoch. + Metrics in this list will be logged as-is. + All others will be averaged over time (e.g. loss, etc). Raises: ValueError: In case of invalid `count_mode`. """ - def __init__(self, count_mode='samples'): + def __init__(self, count_mode='samples', stateful_metrics=None): super(ProgbarLogger, self).__init__() if count_mode == 'samples': self.use_steps = False @@ -285,6 +305,7 @@ class ProgbarLogger(Callback): self.use_steps = True else: raise ValueError('Unknown `count_mode`: ' + str(count_mode)) + self.stateful_metrics = set(stateful_metrics or []) def on_train_begin(self, logs=None): self.verbose = self.params['verbose'] @@ -298,7 +319,10 @@ class ProgbarLogger(Callback): else: target = self.params['samples'] self.target = target - self.progbar = Progbar(target=self.target, verbose=self.verbose) + self.progbar = Progbar( + target=self.target, + verbose=self.verbose, + stateful_metrics=self.stateful_metrics) self.seen = 0 def on_batch_begin(self, batch, logs=None): @@ -328,7 +352,7 @@ class ProgbarLogger(Callback): if k in logs: self.log_values.append((k, logs[k])) if self.verbose: - self.progbar.update(self.seen, self.log_values, force=True) + self.progbar.update(self.seen, self.log_values) @tf_export('keras.callbacks.History') @@ -754,16 +778,24 @@ class TensorBoard(Callback): while i < val_size: step = min(self.batch_size, val_size - i) batch_val = [] - batch_val.append(val_data[0][i:i + step]) - batch_val.append(val_data[1][i:i + step]) - batch_val.append(val_data[2][i:i + step]) + batch_val.append(val_data[0][i:i + step] + if val_data[0] is not None else None) + batch_val.append(val_data[1][i:i + step] + if val_data[1] is not None else None) + batch_val.append(val_data[2][i:i + step] + if val_data[2] is not None else None) if self.model.uses_learning_phase: # do not slice the learning phase - batch_val = [x[i:i + step] for x in val_data[:-1]] + batch_val = [x[i:i + step] if x is not None else None + for x in val_data[:-1]] batch_val.append(val_data[-1]) else: - batch_val = [x[i:i + step] for x in val_data] - feed_dict = dict(zip(tensors, batch_val)) + batch_val = [x[i:i + step] if x is not None else None + for x in val_data] + feed_dict = {} + for key, val in zip(tensors, batch_val): + if val is not None: + feed_dict[key] = val result = self.sess.run([self.merged], feed_dict=feed_dict) summary_str = result[0] self.writer.add_summary(summary_str, epoch) diff --git a/tensorflow/python/keras/_impl/keras/constraints.py b/tensorflow/python/keras/_impl/keras/constraints.py index ab62d575e34c1a43d4b02bf5e4ce7962229ce15a..271fbbb63d3dfd50507837e190860d48315a14f2 100644 --- a/tensorflow/python/keras/_impl/keras/constraints.py +++ b/tensorflow/python/keras/_impl/keras/constraints.py @@ -202,4 +202,5 @@ def get(identifier): elif callable(identifier): return identifier else: - raise ValueError('Could not interpret constraint identifier:', identifier) + raise ValueError('Could not interpret constraint identifier: ' + + str(identifier)) diff --git a/tensorflow/python/keras/_impl/keras/datasets/cifar.py b/tensorflow/python/keras/_impl/keras/datasets/cifar.py index 7ada3340a59e114d73095068ec476da5973b67fb..02344897f774723d0ad690ae641952cb63022bdf 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/cifar.py +++ b/tensorflow/python/keras/_impl/keras/datasets/cifar.py @@ -34,17 +34,16 @@ def load_batch(fpath, label_key='labels'): Returns: A tuple `(data, labels)`. """ - f = open(fpath, 'rb') - if sys.version_info < (3,): - d = cPickle.load(f) - else: - d = cPickle.load(f, encoding='bytes') - # decode utf8 - d_decoded = {} - for k, v in d.items(): - d_decoded[k.decode('utf8')] = v - d = d_decoded - f.close() + with open(fpath, 'rb') as f: + if sys.version_info < (3,): + d = cPickle.load(f) + else: + d = cPickle.load(f, encoding='bytes') + # decode utf8 + d_decoded = {} + for k, v in d.items(): + d_decoded[k.decode('utf8')] = v + d = d_decoded data = d['data'] labels = d[label_key] diff --git a/tensorflow/python/keras/_impl/keras/datasets/imdb.py b/tensorflow/python/keras/_impl/keras/datasets/imdb.py index e2dddf7730f2a922b09de4dadb4dd282b05caf21..7467bb24646227705972262381aa5cf1de809f1c 100644 --- a/tensorflow/python/keras/_impl/keras/datasets/imdb.py +++ b/tensorflow/python/keras/_impl/keras/datasets/imdb.py @@ -144,7 +144,5 @@ def get_word_index(path='imdb_word_index.json'): path, origin='https://s3.amazonaws.com/text-datasets/imdb_word_index.json', file_hash='bfafd718b763782e994055a2d397834f') - f = open(path) - data = json.load(f) - f.close() - return data + with open(path) as f: + return json.load(f) diff --git a/tensorflow/python/keras/_impl/keras/engine/__init__.py b/tensorflow/python/keras/_impl/keras/engine/__init__.py index 31f624f9af65cac60b6466d4eb5753cbdee984c6..1bc533ab8f7ba37948d82bc69fe1c9bfe00d6834 100644 --- a/tensorflow/python/keras/_impl/keras/engine/__init__.py +++ b/tensorflow/python/keras/_impl/keras/engine/__init__.py @@ -18,13 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.keras._impl.keras.engine.topology import get_source_inputs -from tensorflow.python.keras._impl.keras.engine.topology import Input -from tensorflow.python.keras._impl.keras.engine.topology import InputLayer -from tensorflow.python.keras._impl.keras.engine.topology import InputSpec -from tensorflow.python.keras._impl.keras.engine.topology import Layer +from tensorflow.python.keras._impl.keras.engine.base_layer import InputSpec +from tensorflow.python.keras._impl.keras.engine.base_layer import Layer +from tensorflow.python.keras._impl.keras.engine.input_layer import Input +from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer +from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs +from tensorflow.python.keras._impl.keras.engine.network import Network from tensorflow.python.keras._impl.keras.engine.training import Model - - -# Note: topology.Node is an internal class, -# it isn't meant to be used by Keras users. diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..142325041bf4d2f8a564adcf867f3b719435f0ba --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py @@ -0,0 +1,504 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Base layer code (`Layer`). +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from six.moves import zip # pylint: disable=redefined-builtin + +from tensorflow.python.eager import context +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import constraints +from tensorflow.python.keras._impl.keras import initializers +from tensorflow.python.keras._impl.keras import regularizers +from tensorflow.python.keras._impl.keras.utils import generic_utils +from tensorflow.python.layers import base as tf_base_layers +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export + + +# pylint: disable=invalid-name +InputSpec = tf_base_layers.InputSpec +Node = tf_base_layers.Node +TFBaseLayer = tf_base_layers.Layer +# pylint: enable=invalid-name + + +@tf_export('keras.layers.Layer') +class Layer(tf_base_layers.Layer): + """Abstract base layer class. + + # Properties + name: String, must be unique within a model. + input_spec: List of InputSpec class instances + each entry describes one required input: + - ndim + - dtype + A layer with `n` input tensors must have + an `input_spec` of length `n`. + trainable: Boolean, whether the layer weights + will be updated during training. + uses_learning_phase: Whether any operation + of the layer uses `K.in_training_phase()` + or `K.in_test_phase()`. + input_shape: Shape tuple. Provided for convenience, + but note that there may be cases in which this + attribute is ill-defined (e.g. a shared layer + with multiple input shapes), in which case + requesting `input_shape` will raise an Exception. + Prefer using `layer.get_input_shape_for(input_shape)`, + or `layer.get_input_shape_at(node_index)`. + output_shape: Shape tuple. See above. + inbound_nodes: List of nodes. + outbound_nodes: List of nodes. + input, output: Input/output tensor(s). Note that if the layer is used + more than once (shared layer), this is ill-defined + and will raise an exception. In such cases, use + `layer.get_input_at(node_index)`. + input_mask, output_mask: Same as above, for masks. + trainable_weights: List of variables. + non_trainable_weights: List of variables. + weights: The concatenation of the lists trainable_weights and + non_trainable_weights (in this order). + + # Methods + call(x, mask=None): Where the layer's logic lives. + __call__(x, mask=None): Wrapper around the layer logic (`call`). + If x is a Keras tensor: + - Connect current layer with last layer from tensor: + `self._add_inbound_node(last_layer)` + - Add layer to tensor history + If layer is not built: + - Build from inputs shape + get_weights() + set_weights(weights) + get_config() + count_params() + compute_output_shape(input_shape) + compute_mask(x, mask) + get_input_at(node_index) + get_output_at(node_index) + get_input_shape_at(node_index) + get_output_shape_at(node_index) + get_input_mask_at(node_index) + get_output_mask_at(node_index) + + # Class Methods + from_config(config) + + # Internal methods: + build(input_shape) + _add_inbound_node(layer, index=0) + """ + + def __init__(self, **kwargs): + # These properties should be set by the user via keyword arguments. + # note that 'dtype', 'input_shape' and 'batch_input_shape' + # are only applicable to input layers: do not pass these keywords + # to non-input layers. + allowed_kwargs = { + 'activity_regularizer', + 'input_shape', + 'batch_input_shape', + 'batch_size', + 'dtype', + 'name', + 'trainable', + 'weights', + } + # Validate optional keyword arguments. + for kwarg in kwargs: + if kwarg not in allowed_kwargs: + raise TypeError('Keyword argument not understood:', kwarg) + + # Get layer name. + name = kwargs.get('name') + + # Get `trainable` status. + trainable = kwargs.get('trainable', True) + + # Get `dtype`. + dtype = kwargs.get('dtype') + if dtype is None: + dtype = K.floatx() + + # Call super, which will set all properties common to Keras layers + # and core TF layers. + super(Layer, self).__init__( + name=name, dtype=dtype, trainable=trainable, + activity_regularizer=kwargs.get('activity_regularizer')) + + # Add properties that are Keras-only for now. + self.supports_masking = False + + # Manage input shape information if passed. + if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: + # In this case we will later create an input layer + # to insert before the current layer + if 'batch_input_shape' in kwargs: + batch_input_shape = tuple(kwargs['batch_input_shape']) + elif 'input_shape' in kwargs: + if 'batch_size' in kwargs: + batch_size = kwargs['batch_size'] + else: + batch_size = None + batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) + self._batch_input_shape = batch_input_shape + + # Manage initial weight values if passed. + if 'weights' in kwargs: + self._initial_weights = kwargs['weights'] + else: + self._initial_weights = None + + def add_weight(self, + name, + shape, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + constraint=None): + """Adds a weight variable to the layer. + + Arguments: + name: String, the name for the weight variable. + shape: The shape tuple of the weight. + dtype: The dtype of the weight. + initializer: An Initializer instance (callable). + regularizer: An optional Regularizer instance. + trainable: A boolean, whether the weight should + be trained via backprop or not (assuming + that the layer itself is also trainable). + constraint: An optional Constraint instance. + + Returns: + The created weight variable. + """ + if dtype is None: + dtype = K.floatx() + weight = self.add_variable(name, shape, + dtype=dtype, + initializer=initializers.get(initializer), + regularizer=regularizers.get(regularizer), + constraint=constraints.get(constraint), + trainable=trainable) + return weight + + def call(self, inputs, **kwargs): # pylint: disable=unused-argument + """This is where the layer's logic lives. + + Arguments: + inputs: Input tensor, or list/tuple of input tensors. + **kwargs: Additional keyword arguments. + + Returns: + A tensor or list/tuple of tensors. + """ + return inputs + + def __call__(self, inputs, **kwargs): + """Wrapper around self.call(), for handling internal references. + + If a Keras tensor is passed: + - We call self._add_inbound_node(). + - If necessary, we `build` the layer to match + the shape of the input(s). + - We update the _keras_history of the output tensor(s) + with the current layer. + This is done as part of _add_inbound_node(). + + Arguments: + inputs: Can be a tensor or list/tuple of tensors. + **kwargs: Additional keyword arguments to be passed to `call()`. + + Returns: + Output of the layer's `call` method. + + Raises: + ValueError: in case the layer is missing shape information + for its `build` call. + """ + # Actually call the layer (optionally building it). + output = super(Layer, self).__call__(inputs, **kwargs) + if context.in_eager_mode(): + return output + + # Un-built subclassed network: build it + if hasattr(self, '_set_inputs') and not self.inputs: + self._set_inputs(inputs, training=kwargs.get('training')) + + # Update learning phase info. + output_tensors = generic_utils.to_list(output) + uses_lp = any( + [getattr(x, '_uses_learning_phase', False) + for x in generic_utils.to_list(inputs)]) + uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp + for i in range(len(output_tensors)): + output_tensors[i]._uses_learning_phase = getattr( + output_tensors[i], '_uses_learning_phase', False) or uses_lp + + # Optionally load weight values that were specified at layer instantiation. + if hasattr(self, '_initial_weights') and self._initial_weights is not None: + self.set_weights(self._initial_weights) + del self._initial_weights + return output + + def compute_output_shape(self, input_shape): + """Computes the output shape of the layer. + + Assumes that the layer will be built + to match that input shape provided. + + Arguments: + input_shape: Shape tuple (tuple of integers) + or list of shape tuples (one per output tensor of the layer). + Shape tuples can include None for free dimensions, + instead of an integer. + + Returns: + An input shape tuple. + """ + logging.warning( + 'All custom layers should implement the ' + '`compute_output_shape` method. This layer (' + self.name + ') ' + 'is relying on the base `Layer.compute_output_shape` implementation, ' + 'which will start raising a `NotImplementedError` ' + 'as of July 1st, 2018.') + return input_shape + + def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument + """Computes an output mask tensor. + + Arguments: + inputs: Tensor or list of tensors. + mask: Tensor or list of tensors. + + Returns: + None or a tensor (or list of tensors, + one per output tensor of the layer). + """ + if not self.supports_masking: + if mask is not None: + if isinstance(mask, list): + if any(m is not None for m in mask): + raise TypeError('Layer ' + self.name + ' does not support masking, ' + 'but was passed an input_mask: ' + str(mask)) + else: + raise TypeError('Layer ' + self.name + ' does not support masking, ' + 'but was passed an input_mask: ' + str(mask)) + # masking not explicitly supported: return None as mask + return None + # if masking is explicitly supported, by default + # carry over the input mask + return mask + + def get_input_mask_at(self, node_index): + """Retrieves the input mask tensor(s) of a layer at a given node. + + Arguments: + node_index: Integer, index of the node + from which to retrieve the attribute. + E.g. `node_index=0` will correspond to the + first time the layer was called. + + Returns: + A mask tensor + (or list of tensors if the layer has multiple inputs). + """ + inputs = self.get_input_at(node_index) + if isinstance(inputs, list): + return [getattr(x, '_keras_mask', None) for x in inputs] + else: + return getattr(inputs, '_keras_mask', None) + + def get_output_mask_at(self, node_index): + """Retrieves the output mask tensor(s) of a layer at a given node. + + Arguments: + node_index: Integer, index of the node + from which to retrieve the attribute. + E.g. `node_index=0` will correspond to the + first time the layer was called. + + Returns: + A mask tensor + (or list of tensors if the layer has multiple outputs). + """ + output = self.get_output_at(node_index) + if isinstance(output, list): + return [getattr(x, '_keras_mask', None) for x in output] + else: + return getattr(output, '_keras_mask', None) + + @property + def input_mask(self): + """Retrieves the input mask tensor(s) of a layer. + + Only applicable if the layer has exactly one inbound node, + i.e. if it is connected to one incoming layer. + + Returns: + Input mask tensor (potentially None) or list of input + mask tensors. + + Raises: + AttributeError: if the layer is connected to + more than one incoming layers. + """ + inputs = self.input + if isinstance(inputs, list): + return [getattr(x, '_keras_mask', None) for x in inputs] + else: + return getattr(inputs, '_keras_mask', None) + + @property + def output_mask(self): + """Retrieves the output mask tensor(s) of a layer. + + Only applicable if the layer has exactly one inbound node, + i.e. if it is connected to one incoming layer. + + Returns: + Output mask tensor (potentially None) or list of output + mask tensors. + + Raises: + AttributeError: if the layer is connected to + more than one incoming layers. + """ + output = self.output + if isinstance(output, list): + return [getattr(x, '_keras_mask', None) for x in output] + else: + return getattr(output, '_keras_mask', None) + + def set_weights(self, weights): + """Sets the weights of the layer, from Numpy arrays. + + Arguments: + weights: a list of Numpy arrays. The number + of arrays and their shape must match + number of the dimensions of the weights + of the layer (i.e. it should match the + output of `get_weights`). + + Raises: + ValueError: If the provided weights list does not match the + layer's specifications. + """ + params = self.weights + if len(params) != len(weights): + raise ValueError('You called `set_weights(weights)` on layer "' + + self.name + '" with a weight list of length ' + + str(len(weights)) + ', but the layer was expecting ' + + str(len(params)) + ' weights. Provided weights: ' + + str(weights)[:50] + '...') + if not params: + return + weight_value_tuples = [] + param_values = K.batch_get_value(params) + for pv, p, w in zip(param_values, params, weights): + if pv.shape != w.shape: + raise ValueError('Layer weight shape ' + str(pv.shape) + + ' not compatible with ' + 'provided weight shape ' + str(w.shape)) + weight_value_tuples.append((p, w)) + K.batch_set_value(weight_value_tuples) + + def get_weights(self): + """Returns the current weights of the layer. + + Returns: + Weights values as a list of numpy arrays. + """ + params = self.weights + return K.batch_get_value(params) + + def get_config(self): + """Returns the config of the layer. + + A layer config is a Python dictionary (serializable) + containing the configuration of a layer. + The same layer can be reinstantiated later + (without its trained weights) from this configuration. + + The config of a layer does not include connectivity + information, nor the layer class name. These are handled + by `Network` (one layer of abstraction above). + + Returns: + Python dictionary. + """ + config = {'name': self.name, 'trainable': self.trainable} + if hasattr(self, '_batch_input_shape'): + config['batch_input_shape'] = self._batch_input_shape + if hasattr(self, 'dtype'): + config['dtype'] = self.dtype + return config + + @classmethod + def from_config(cls, config): + """Creates a layer from its config. + + This method is the reverse of `get_config`, + capable of instantiating the same layer from the config + dictionary. It does not handle layer connectivity + (handled by Network), nor weights (handled by `set_weights`). + + Arguments: + config: A Python dictionary, typically the + output of get_config. + + Returns: + A layer instance. + """ + return cls(**config) + + @tf_base_layers.Layer.activity_regularizer.setter + def activity_regularizer(self, activity_regularizer): + self._activity_regularizer = activity_regularizer + + +def shape_type_conversion(fn): + """Decorator that handles tuple/TensorShape conversion. + + Used in `compute_output_shape` and `build`. + + Arguments: + fn: function to wrap. + + Returns: + Wrapped function. + """ + + def wrapper(instance, input_shape): + if input_shape is not None: + if isinstance(input_shape, list): + input_shape = [ + tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape] + else: + input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) + output_shape = fn(instance, input_shape) + if output_shape is not None: + if isinstance(output_shape, list): + return [tensor_shape.TensorShape(x) for x in output_shape] + return tensor_shape.TensorShape(output_shape) + + return wrapper diff --git a/tensorflow/python/keras/_impl/keras/engine/input_layer.py b/tensorflow/python/keras/_impl/keras/engine/input_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..8f9ea6f7a40e49ec45dfaeb14f807cd9c7db65c9 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/input_layer.py @@ -0,0 +1,230 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Input layer code (`Input` and `InputLayer`). +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.eager import context +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.layers import base as tf_base_layers +from tensorflow.python.ops import array_ops +from tensorflow.python.util.tf_export import tf_export + + +class InputLayer(base_layer.Layer): + """Layer to be used as an entry point into a Network (a graph of layers). + + It can either wrap an existing tensor (pass an `input_tensor` argument) + or create its a placeholder tensor (pass arguments `input_shape`, and + optionally, `dtype`). + + It is generally recommend to use the functional layer API via `Input`, + (which creates an `InputLayer`) without directly using `InputLayer`. + + Arguments: + input_shape: Shape tuple (not including the batch axis), or `TensorShape` + instance (not including the batch axis). + batch_size: Optional input batch size (integer or None). + dtype: Datatype of the input. + input_tensor: Optional tensor to use as layer input + instead of creating a placeholder. + sparse: Boolean, whether the placeholder created + is meant to be sparse. + name: Name of the layer (string). + """ + + def __init__(self, + input_shape=None, + batch_size=None, + dtype=None, + input_tensor=None, + sparse=False, + name=None, + **kwargs): + if 'batch_input_shape' in kwargs: + batch_input_shape = kwargs.pop('batch_input_shape') + if input_shape and batch_input_shape: + raise ValueError('Only provide the input_shape OR ' + 'batch_input_shape argument to ' + 'InputLayer, not both at the same time.') + batch_size = batch_input_shape[0] + input_shape = batch_input_shape[1:] + if kwargs: + raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) + + if not name: + prefix = 'input' + name = prefix + '_' + str(K.get_uid(prefix)) + + if not dtype: + if input_tensor is None: + dtype = K.floatx() + else: + dtype = K.dtype(input_tensor) + super(InputLayer, self).__init__(dtype=dtype, name=name) + self.built = True + self.sparse = sparse + self.batch_size = batch_size + + if isinstance(input_shape, tensor_shape.TensorShape): + input_shape = tuple(input_shape.as_list()) + + if input_tensor is None: + if input_shape is not None: + batch_input_shape = (batch_size,) + tuple(input_shape) + else: + batch_input_shape = None + + if context.in_eager_mode(): + # In eager mode, create a temporary placeholder to call the layer on. + input_tensor = tf_base_layers._DeferredTensor( # pylint: disable=protected-access + shape=batch_input_shape, + dtype=dtype, + name=self.name) + else: + # In graph mode, create a graph placeholder to call the layer on. + if sparse: + input_tensor = array_ops.sparse_placeholder( + shape=batch_input_shape, + dtype=dtype, + name=self.name) + else: + input_tensor = array_ops.placeholder( + shape=batch_input_shape, + dtype=dtype, + name=self.name) + + # For compatibility with Keras API. + self.is_placeholder = True + self._batch_input_shape = batch_input_shape + else: + # For compatibility with Keras API. + self.is_placeholder = False + self._batch_input_shape = tuple(input_tensor.get_shape().as_list()) + + # Create an input node to add to self.outbound_node + # and set output_tensors' _keras_history. + input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access + tf_base_layers.Node( + self, + inbound_layers=[], + node_indices=[], + tensor_indices=[], + input_tensors=[input_tensor], + output_tensors=[input_tensor]) + + def get_config(self): + config = { + 'batch_input_shape': self._batch_input_shape, + 'dtype': self.dtype, + 'sparse': self.sparse, + 'name': self.name + } + return config + + +@tf_export('keras.layers.Input', 'keras.Input') +def Input( # pylint: disable=invalid-name + shape=None, + batch_size=None, + name=None, + dtype=None, + sparse=False, + tensor=None, + **kwargs): + """`Input()` is used to instantiate a Keras tensor. + + A Keras tensor is a tensor object from the underlying backend + (Theano or TensorFlow), which we augment with certain + attributes that allow us to build a Keras model + just by knowing the inputs and outputs of the model. + + For instance, if a, b and c are Keras tensors, + it becomes possible to do: + `model = Model(input=[a, b], output=c)` + + The added Keras attribute is: + `_keras_history`: Last layer applied to the tensor. + the entire layer graph is retrievable from that layer, + recursively. + + Arguments: + shape: A shape tuple (integers), not including the batch size. + For instance, `shape=(32,)` indicates that the expected input + will be batches of 32-dimensional vectors. + batch_size: optional static batch size (integer). + name: An optional name string for the layer. + Should be unique in a model (do not reuse the same name twice). + It will be autogenerated if it isn't provided. + dtype: The data type expected by the input, as a string + (`float32`, `float64`, `int32`...) + sparse: A boolean specifying whether the placeholder + to be created is sparse. + tensor: Optional existing tensor to wrap into the `Input` layer. + If set, the layer will not create a placeholder tensor. + **kwargs: deprecated arguments support. + + Returns: + A tensor. + + Example: + + ```python + # this is a logistic regression in Keras + x = Input(shape=(32,)) + y = Dense(16, activation='softmax')(x) + model = Model(x, y) + ``` + + Raises: + ValueError: in case of invalid arguments. + """ + if 'batch_shape' in kwargs: + batch_shape = kwargs.pop('batch_shape') + if shape and batch_shape: + raise ValueError('Only provide the shape OR ' + 'batch_shape argument to ' + 'Input, not both at the same time.') + batch_size = batch_shape[0] + shape = batch_shape[1:] + if kwargs: + raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) + + if dtype is None: + dtype = K.floatx() + if not shape and tensor is None: + raise ValueError('Please provide to Input either a `shape`' + ' or a `tensor` argument. Note that ' + '`shape` does not include the batch ' + 'dimension.') + input_layer = InputLayer( + input_shape=shape, + batch_size=batch_size, + name=name, + dtype=dtype, + sparse=sparse, + input_tensor=tensor) + # Return tensor including `_keras_history`. + # Note that in this case train_output and test_output are the same pointer. + outputs = input_layer._inbound_nodes[0].output_tensors + if len(outputs) == 1: + return outputs[0] + else: + return outputs diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py new file mode 100644 index 0000000000000000000000000000000000000000..453cc8f8b7268376f48f07f5c8cf788a0aa10ddc --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/network.py @@ -0,0 +1,1501 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""A `Network` is way to compose layers: the topological form of a `Model`. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import json +import os + +import numpy as np +from six.moves import zip # pylint: disable=redefined-builtin + +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras._impl.keras.engine import saving +from tensorflow.python.keras._impl.keras.utils import generic_utils +from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary +from tensorflow.python.layers import base as tf_base_layers +from tensorflow.python.layers import utils as tf_layers_util +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect + + +# pylint: disable=g-import-not-at-top +try: + import h5py +except ImportError: + h5py = None + +try: + import yaml +except ImportError: + yaml = None +# pylint: enable=g-import-not-at-top + + +class Network(base_layer.Layer): + """A `Network` is a composition of layers. + + It is the topological form of a "model". A `Model` + is simply a `Network` with added training routines. + """ + + def __init__(self, *args, **kwargs): # pylint: disable=super-init-not-called + # Signature detection + if (len(args) == 2 or + len(args) == 1 and 'outputs' in kwargs or + 'inputs' in kwargs and 'outputs' in kwargs): + # Graph network + self._init_graph_network(*args, **kwargs) + else: + # Subclassed network + self._init_subclassed_network(**kwargs) + + def _base_init(self, name=None): + # The following are implemented as property functions: + # self.trainable_weights + # self.non_trainable_weights + # self.input_spec + # self.losses + # self.updates + + self._init_set_name(name) + self._activity_regularizer = None + # This acts just like the `trainable` attribute of any layer instance. + # It does not affect users of the underlying layers, only users of the + # Network instance. + self.trainable = True + self._is_compiled = False + self._expects_training_arg = False + + self.supports_masking = False + self.optimizer = None + + # Private attributes to implement compatibility with Layer. + self._updates = [] # Used in symbolic mode only. + self._losses = [] # Used in symbolic mode only. + self._scope = None # Never used. + self._reuse = None # Never used. + if context.in_eager_mode: + self._graph = None + else: + self._graph = ops.get_default_graph() # Used in symbolic mode only. + # A Network does not create weights of its own, thus has no dtype. + self._dtype = None + + # All layers in order of horizontal graph traversal. + # Entries are unique. Includes input and output layers. + self._layers = [] + + # Used in symbolic mode only, only in conjonction with graph-networks + self._outbound_nodes = [] + self._inbound_nodes = [] + + def _init_graph_network(self, inputs, outputs, name=None): + # Normalize and set self.inputs, self.outputs. + if isinstance(inputs, (list, tuple)): + self.inputs = list(inputs) # Tensor or list of tensors. + else: + self.inputs = [inputs] + if isinstance(outputs, (list, tuple)): + self.outputs = list(outputs) + else: + self.outputs = [outputs] + + # User-prodived argument validation. + if context.in_eager_mode(): + # Check that all inputs/outputs are DeferredTensors. + for tensor in self.inputs: + if not isinstance(tensor, tf_base_layers._DeferredTensor): # pylint: disable=protected-access + raise TypeError('When eager execution is enabled, ' + 'inputs must come from a call to ' + '`tf.keras.Input` (called after ' + 'tfe.enable_eager_execution()). ' + 'Received invalid input: ' + str(tensor)) + for tensor in self.outputs: + if not isinstance(tensor, tf_base_layers._DeferredTensor): # pylint: disable=protected-access + raise TypeError('When eager execution is enabled, ' + 'outputs must come from a call to ' + 'a layer (called after ' + 'tfe.enable_eager_execution()). ' + 'Received invalid output: ' + str(tensor)) + # Check for redundancy in inputs. + if len(set(self.inputs)) != len(self.inputs): + raise ValueError('The list of inputs passed to the model ' + 'is redundant. ' + 'All inputs should only appear once.' + ' Found: ' + str(self.inputs)) + for x in self.inputs: + # Check that x has appropriate `_keras_history` metadata. + if not hasattr(x, '_keras_history'): + cls_name = self.__class__.__name__ + raise ValueError('Input tensors to a ' + cls_name + ' ' + + 'must come from `tf.layers.Input`. ' + 'Received: ' + str(x) + + ' (missing previous layer metadata).') + # Check that x is an input tensor. + # pylint: disable=protected-access + layer, node_index, tensor_index = x._keras_history + if len(layer._inbound_nodes) > 1 or ( + layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers): + cls_name = self.__class__.__name__ + logging.warning(cls_name + ' inputs must come from ' + '`tf.layers.Input` (thus holding past layer metadata), ' + 'they cannot be the output of ' + 'a previous non-Input layer. ' + 'Here, a tensor specified as ' + 'input to "' + self.name + '" was not an Input tensor, ' + 'it was generated by layer ' + layer.name + '.\n' + 'Note that input tensors are ' + 'instantiated via `tensor = tf.layers.Input(shape)`.\n' + 'The tensor that caused the issue was: ' + str(x.name)) + for x in self.outputs: + if not hasattr(x, '_keras_history'): + cls_name = self.__class__.__name__ + raise ValueError('Output tensors to a ' + cls_name + ' must be ' + 'the output of a TensorFlow `Layer` ' + '(thus holding past layer metadata). Found: ' + str(x)) + + self._base_init(name=name) + self._compute_previous_mask = ( + 'mask' in tf_inspect.getargspec(self.call).args or + hasattr(self, 'compute_mask')) + # A Network does not create weights of its own, thus it is already + # built. + self.built = True + self._is_graph_network = True + + # # List of initial layers (1 to 1 mapping with self.inputs, + # # hence the same layer might appear twice) + # self._input_layers = [] + # self._input_layers_node_indices = [] + # self._input_layers_tensor_indices = [] + # # list of layers (1 to 1 mapping with self.inputs, + # # hence the same layer might appear twice) + # self._output_layers = [] + # self._output_layers_node_indices = [] + # self._output_layers_tensor_indices = [] + + self._input_layers = [] + self._output_layers = [] + self._input_coordinates = [] + self._output_coordinates = [] + + # This is for performance optimization when calling the Network on new + # inputs. Every time the Network is called on a set on input tensors, + # we compute the output tensors, output masks and output shapes in one pass, + # then cache them here. When any of these outputs is queried later, we + # retrieve it from there instead of recomputing it. + self._output_mask_cache = {} + self._output_tensor_cache = {} + self._output_shape_cache = {} + + # Build self._output_layers: + for x in self.outputs: + layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access + self._output_layers.append(layer) + self._output_coordinates.append((layer, node_index, tensor_index)) + + # Build self._input_layers: + for x in self.inputs: + layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access + # It's supposed to be an input layer, so only one node + # and one tensor output. + assert node_index == 0 + assert tensor_index == 0 + self._input_layers.append(layer) + self._input_coordinates.append((layer, node_index, tensor_index)) + + # Keep track of the network's nodes and layers. + nodes, nodes_by_depth, layers, layers_by_depth = _map_graph_network( + self.inputs, self.outputs) + self._network_nodes = nodes + self._nodes_by_depth = nodes_by_depth + self._layers = layers + self._layers_by_depth = layers_by_depth + + # Create the node linking internal inputs to internal outputs. + tf_base_layers.Node( + outbound_layer=self, + inbound_layers=[], + node_indices=[], + tensor_indices=[], + input_tensors=self.inputs, + output_tensors=self.outputs) + + # Fill in the output mask cache. + masks = [] + for x in self.inputs: + mask = x._keras_mask if hasattr(x, '_keras_mask') else None # pylint: disable=protected-access + masks.append(mask) + mask_cache_key = (tf_layers_util.object_list_uid(self.inputs) + '_' + + tf_layers_util.object_list_uid(masks)) + masks = [] + for x in self.outputs: + mask = x._keras_mask if hasattr(x, '_keras_mask') else None # pylint: disable=protected-access + masks.append(mask) + if len(masks) == 1: + mask = masks[0] + else: + mask = masks + self._output_mask_cache[mask_cache_key] = mask + + # Build self.input_names and self.output_names. + self.input_names = [] + self.output_names = [] + self._feed_input_names = [] + self._feed_inputs = [] + self._feed_input_shapes = [] + for i, layer in enumerate(self._input_layers): + self.input_names.append(layer.name) + if layer.is_placeholder: + self._feed_input_names.append(layer.name) + self._feed_input_shapes.append(K.int_shape(self.inputs[i])) + # layer.input gives an error in eager mode + if context.in_graph_mode(): + self._feed_inputs.append(layer.input) + for layer in self._output_layers: + self.output_names.append(layer.name) + + def _init_subclassed_network(self, name=None): + self._base_init(name=name) + self._is_graph_network = False + if 'training' in tf_inspect.getargspec(self.call).args: + self._expects_training_arg = True + else: + self._expects_training_arg = False + + self.outputs = None + self.inputs = None + self.built = False + + def __setattr__(self, name, value): + if isinstance(value, (tf_base_layers.Layer, Network)): + try: + is_graph_network = self._is_graph_network + except AttributeError: + raise RuntimeError('It looks like you are subclassing `Model` and you ' + 'forgot to call `super(YourClass, self).__init__()`.' + ' Always start with this line.') + if not is_graph_network: + if value not in self._layers: + self._layers.append(value) + super(Network, self).__setattr__(name, value) + + def add_variable(self, name, shape, dtype=None, initializer=None, + regularizer=None, trainable=True, constraint=None): + raise NotImplementedError('`add_variable` is not supported on Networks.') + + def add_loss(self, *args, **kwargs): + if context.in_eager_mode(): + raise NotImplementedError('`add_loss` is not supported on Networks ' + 'when eager execution is enabled.') + super(Network, self).add_loss(*args, **kwargs) + + @property + def uses_learning_phase(self): + return any( + [getattr(x, '_uses_learning_phase', False) for x in self.outputs]) + + @property + def stateful(self): + return any([(hasattr(layer, 'stateful') and layer.stateful) + for layer in self.layers]) + + def reset_states(self): + for layer in self.layers: + if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False): + layer.reset_states() + + @property + def state_updates(self): + """Returns the `updates` from all layers that are stateful. + + This is useful for separating training updates and + state updates, e.g. when we need to update a layer's internal state + during prediction. + + Returns: + A list of update ops. + """ + state_updates = [] + for layer in self.layers: + if getattr(layer, 'stateful', False): + if hasattr(layer, 'updates'): + state_updates += layer.updates + return state_updates + + def get_weights(self): + """Retrieves the weights of the model. + + Returns: + A flat list of Numpy arrays. + """ + weights = [] + for layer in self.layers: + weights += layer.weights + return K.batch_get_value(weights) + + def set_weights(self, weights): + """Sets the weights of the model. + + Arguments: + weights: A list of Numpy arrays with shapes and types matching + the output of `model.get_weights()`. + """ + tuples = [] + for layer in self.layers: + num_param = len(layer.weights) + layer_weights = weights[:num_param] + for sw, w in zip(layer.weights, layer_weights): + tuples.append((sw, w)) + weights = weights[num_param:] + K.batch_set_value(tuples) + + def compute_mask(self, inputs, mask): + if not self._is_graph_network: + return None + + inputs = generic_utils.to_list(inputs) + if mask is None: + masks = [None for _ in range(len(inputs))] + else: + masks = generic_utils.to_list(mask) + cache_key = (tf_layers_util.object_list_uid(inputs) + + '_' + tf_layers_util.object_list_uid(masks)) + if cache_key in self._output_mask_cache: + return self._output_mask_cache[cache_key] + else: + _, output_masks = self._run_internal_graph(inputs, masks) + return output_masks + + @property + def layers(self): + return self._layers + + def get_layer(self, name=None, index=None): + """Retrieves a layer based on either its name (unique) or index. + + Indices are based on order of horizontal graph traversal (bottom-up). + + Arguments: + name: String, name of layer. + index: Integer, index of layer. + + Returns: + A layer instance. + + Raises: + ValueError: In case of invalid layer name or index. + """ + # TODO(fchollet): We could build a dictionary based on layer names + # since they are constant, but we have not done that yet. + if index is not None: + if len(self.layers) <= index: + raise ValueError('Was asked to retrieve layer at index ' + str(index) + + ' but model only has ' + str(len(self.layers)) + + ' layers.') + else: + return self.layers[index] + else: + if not name: + raise ValueError('Provide either a layer name or layer index.') + for layer in self.layers: + if layer.name == name: + return layer + raise ValueError('No such layer: ' + name) + + @property + def updates(self): + """Retrieve the network's updates. + + Will only include updates that are either + unconditional, or conditional on inputs to this model + (e.g. will not include updates that were created by layers of this model + outside of the model). + + Effectively, `network.updates` behaves like `layer.updates`. + + Concrete example: + + ```python + bn = keras.layers.BatchNormalization() + x1 = keras.layers.Input(shape=(10,)) + _ = bn(x1) # This creates 2 updates. + + x2 = keras.layers.Input(shape=(10,)) + y2 = bn(x2) # This creates 2 more updates. + + # The BN layer has now 4 updates. + self.assertEqual(len(bn.updates), 4) + + # Let's create a model from x2 to y2. + model = keras.models.Model(x2, y2) + + # The model does not list all updates from its underlying layers, + # but only the updates that are relevant to it. Updates created by layers + # outside of the model are discarded. + self.assertEqual(len(model.updates), 2) + + # If you keep calling the model, you append to its updates, just like + # what happens for a layer. + x3 = keras.layers.Input(shape=(10,)) + y3 = model(x3) + self.assertEqual(len(model.updates), 4) + + # But if you call the inner BN layer independently, you don't affect + # the model's updates. + x4 = keras.layers.Input(shape=(10,)) + _ = bn(x4) + self.assertEqual(len(model.updates), 4) + ``` + + Returns: + A list of update ops. + """ + if context.in_eager_mode(): + return [] + + if not self.trainable and not self.stateful: + return [] + + updates = [] + for layer in self.layers: + updates += layer.updates + + # `updates` might contain irrelevant updates, so it needs to be filtered + # with respect to inputs the model has been called on. + relevant_inputs = self.inputs or [] + for i in range(1, len(self._inbound_nodes)): + inputs = self.get_input_at(i) + if isinstance(inputs, list): + relevant_inputs += inputs + else: + relevant_inputs.append(inputs) + reachable = tf_layers_util.get_reachable_from_inputs(relevant_inputs, + updates) + relevant_conditional_updates = [x for x in updates if x in reachable] + unconditional_updates = [ + x for x in updates if x._unconditional_update] # pylint: disable=protected-access + # A layer could be used multiple times in a nested structure, + # so the updates list must be de-duped. + return list(set( + relevant_conditional_updates + unconditional_updates + self._updates)) + + @property + def losses(self): + """Retrieve the network's losses. + + Will only include losses that are either + unconditional, or conditional on inputs to this model + (e.g. will not include losses that depend on tensors + that aren't inputs to this model). + + Returns: + A list of loss tensors. + """ + losses = [] + for layer in self.layers: + losses += layer.losses + if context.in_eager_mode(): + return losses + + relevant_inputs = self.inputs or [] + for i in range(1, len(self._inbound_nodes)): + inputs = self.get_input_at(i) + if isinstance(inputs, list): + relevant_inputs += inputs + else: + relevant_inputs.append(inputs) + reachable = tf_layers_util.get_reachable_from_inputs(relevant_inputs, + losses) + relevant_conditional_losses = [x for x in losses if x in reachable] + unconditional_losses = [ + x for x in losses if x._unconditional_loss] # pylint: disable=protected-access + return list(set( + relevant_conditional_losses + unconditional_losses + self._losses)) + + @property + def trainable_weights(self): + if not self.trainable: + return [] + weights = [] + for layer in self.layers: + weights += layer.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for layer in self.layers: + weights += layer.non_trainable_weights + if not self.trainable: + trainable_weights = [] + for layer in self.layers: + trainable_weights += layer.trainable_weights + return trainable_weights + weights + return weights + + @property + def input_spec(self): + """Gets the network's input specs. + + Returns: + A list of `InputSpec` instances (one per input to the model) + or a single instance if the model has only one input. + """ + # If not a graph network, can't assume anything. + if not self._is_graph_network: + return None + + specs = [] + for layer in self._input_layers: + if layer.input_spec is None: + specs.append(None) + else: + if not isinstance(layer.input_spec, list): + raise TypeError('Layer ' + layer.name + + ' has an input_spec attribute that ' + 'is not a list. We expect a list. ' + 'Found input_spec = ' + str(layer.input_spec)) + specs += layer.input_spec + if len(specs) == 1: + return specs[0] + return specs + + def call(self, inputs, training=None, mask=None): + """Call the model on new inputs. + + In this case `call` just reapplies + all ops in the graph to the new inputs + (e.g. build a new computational graph from the provided inputs). + + Arguments: + inputs: A tensor or list of tensors. + training: Boolean or boolean scalar tensor, indicating whether to run + the `Network` in training mode or inference mode. + mask: A mask or list of masks. A mask can be + either a tensor or None (no mask). + + Returns: + A tensor if there is a single output, or + a list of tensors if there are more than one outputs. + """ + inputs = nest.flatten(inputs) + if mask is None: + masks = [None for _ in range(len(inputs))] + else: + masks = nest.flatten(mask) + + if context.in_graph_mode(): + # Try to retrieve cached outputs if the layer has already been called + # on these exact inputs. + cache_key = (tf_layers_util.object_list_uid(inputs) + + '_' + tf_layers_util.object_list_uid(masks)) + if cache_key in self._output_tensor_cache: + # Cache hit. + return self._output_tensor_cache[cache_key] + # Actually apply the network graph to the new inputs. + outputs, _ = self._run_internal_graph(inputs, + training=training, + mask=masks) + return outputs + + def compute_output_shape(self, input_shape): + if not self._is_graph_network: + raise NotImplementedError + + if isinstance(input_shape, list): + input_shapes = [] + for shape in input_shape: + if shape is not None: + input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list())) + else: + input_shapes.append(None) + else: + if input_shape is not None: + input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())] + else: + input_shapes = [None] + + if len(input_shapes) != len(self._input_layers): + raise ValueError('Invalid input_shape argument ' + str(input_shape) + + ': model has ' + str(len(self._input_layers)) + + ' tensor inputs.') + + cache_key = tf_layers_util.object_list_uid(input_shapes) + if cache_key not in self._output_shape_cache: + # Cache miss. We have to run the network graph manually (recursive calls + # to `compute_output_shape`). + layers_to_output_shapes = {} + for i in range(len(input_shapes)): + layer = self._input_layers[i] + input_shape = input_shapes[i] + # It's an input layer: then `compute_output_shape` is identity, + # and there is only one node and one tensor output. + shape_key = layer.name + '_0_0' + layers_to_output_shapes[shape_key] = input_shape + + depth_keys = list(self._nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + # Iterate over nodes, by depth level. + if len(depth_keys) > 1: + for depth in depth_keys: + nodes = self._nodes_by_depth[depth] + for node in nodes: + # This is always a single layer, never a list. + layer = node.outbound_layer + if layer in self._input_layers: + # We've already covered the input layers + # a few lines above. + continue + # Potentially redundant list, + # same size as node.input_tensors. + input_shapes = [] + for j in range(len(node.inbound_layers)): + inbound_layer = node.inbound_layers[j] + node_index = node.node_indices[j] + tensor_index = node.tensor_indices[j] + shape_key = inbound_layer.name + '_%s_%s' % (node_index, + tensor_index) + input_shape = layers_to_output_shapes[shape_key] + input_shapes.append(input_shape) + + if len(input_shapes) == 1: + output_shape = layer.compute_output_shape(input_shapes[0]) + else: + output_shape = layer.compute_output_shape(input_shapes) + if isinstance(output_shape, list): + output_shapes = [ + tuple(tensor_shape.TensorShape(shape).as_list()) + for shape in output_shape + ] + else: + output_shapes = [ + tuple(tensor_shape.TensorShape(output_shape).as_list()) + ] + + node_index = layer._inbound_nodes.index(node) # pylint: disable=protected-access + for j in range(len(output_shapes)): + shape_key = layer.name + '_%s_%s' % (node_index, j) + layers_to_output_shapes[shape_key] = output_shapes[j] + + # Read final output shapes from layers_to_output_shapes. + output_shapes = [] + for i in range(len(self._output_layers)): + layer, node_index, tensor_index = self._output_coordinates[i] + shape_key = layer.name + '_%s_%s' % (node_index, tensor_index) + output_shapes.append(layers_to_output_shapes[shape_key]) + # Store in cache. + self._output_shape_cache[cache_key] = output_shapes + else: + # Cache hit. + output_shapes = self._output_shape_cache[cache_key] + + if isinstance(output_shapes, list): + if len(output_shapes) == 1: + return tensor_shape.TensorShape(output_shapes[0]) + else: + return [tensor_shape.TensorShape(shape) for shape in output_shapes] + else: + return tensor_shape.TensorShape(output_shapes) + + def _run_internal_graph(self, inputs, training=None, mask=None): + """Computes output tensors for new inputs. + + # Note: + - Expects `inputs` to be a list (potentially with 1 element). + - Can be run on non-Keras tensors. + + Arguments: + inputs: List of tensors + training: Boolean learning phase. + mask: List of masks (tensors or None). + + Returns: + Three lists: output_tensors, output_masks, output_shapes + """ + # Note: masking support is relevant mainly for Keras. + # It cannot be factored out without having the fully reimplement the network + # calling logic on the Keras side. We choose to incorporate it in + # Network because 1) it may be useful to fully support in tf.layers in + # the future and 2) Keras is a major user of Network. If you don't + # use masking, it does not interfere with regular behavior at all and you + # can ignore it. + if mask is None: + masks = [None for _ in range(len(inputs))] + else: + masks = mask + + # Dictionary mapping reference tensors to tuples + # (computed tensor, compute mask) + # we assume a 1:1 mapping from tensor to mask + # TODO(fchollet): raise exception when a `.compute_mask()` call + # does not return a list the same size as `call` + tensor_map = {} + for x, y, mask in zip(self.inputs, inputs, masks): + tensor_map[str(id(x))] = (y, mask) + + depth_keys = list(self._nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + for depth in depth_keys: + nodes = self._nodes_by_depth[depth] + for node in nodes: + # This is always a single layer, never a list. + layer = node.outbound_layer + reference_input_tensors = node.input_tensors + reference_output_tensors = node.output_tensors + + # If all previous input tensors are available in tensor_map, + # then call node.inbound_layer on them. + computed_data = [] # List of tuples (input, mask). + for x in reference_input_tensors: + if str(id(x)) in tensor_map: + computed_data.append(tensor_map[str(id(x))]) + + if len(computed_data) == len(reference_input_tensors): + # Call layer (reapplying ops to new inputs). + with ops.name_scope(layer.name): + if node.arguments: + kwargs = node.arguments + else: + kwargs = {} + if len(computed_data) == 1: + computed_tensor, computed_mask = computed_data[0] + # Ensure mask propagation if applicable. + if 'mask' in tf_inspect.getargspec(layer.call).args: + kwargs.setdefault('mask', computed_mask) + if 'training' in tf_inspect.getargspec(layer.call).args: + kwargs.setdefault('training', training) + + output_tensors = nest.flatten( + layer.call(computed_tensor, **kwargs)) + if hasattr(layer, 'compute_mask'): + output_masks = nest.flatten( + layer.compute_mask(computed_tensor, computed_mask)) + else: + output_masks = [None for _ in range(len(output_tensors))] + computed_tensors = [computed_tensor] + computed_masks = [computed_mask] + else: + computed_tensors = [x[0] for x in computed_data] + computed_masks = [x[1] for x in computed_data] + if 'mask' in tf_inspect.getargspec(layer.call).args: + kwargs.setdefault('mask', computed_masks) + if 'training' in tf_inspect.getargspec(layer.call).args: + kwargs.setdefault('training', training) + + output_tensors = nest.flatten( + layer.call(computed_tensors, **kwargs)) + if hasattr(layer, 'compute_mask'): + output_masks = nest.flatten( + layer.compute_mask(computed_tensors, computed_masks)) + else: + output_masks = [None for _ in range(len(output_tensors))] + + if context.in_graph_mode(): + if layer.activity_regularizer is not None: + regularization_losses = [ + layer.activity_regularizer(x) for x in output_tensors + ] + # Apply activity regularizer if any: + layer.add_loss(regularization_losses, computed_tensors) + + # Update tensor_map. + for x, y, mask in zip(reference_output_tensors, output_tensors, + output_masks): + tensor_map[str(id(x))] = (y, mask) + + output_tensors = [] + output_masks = [] + output_shapes = [] + for x in self.outputs: + assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x) + tensor, mask = tensor_map[str(id(x))] + output_shapes.append(tf_layers_util.static_shape(x)) + output_tensors.append(tensor) + output_masks.append(mask) + + if len(output_tensors) == 1: + output_tensors = output_tensors[0] + if output_shapes is not None: + output_shapes = output_shapes[0] + if output_masks is not None: + output_masks = output_masks[0] + + if context.in_graph_mode(): + # Update cache; + # keys are based on ids on input tensors and inputs masks. + cache_key = (tf_layers_util.object_list_uid(inputs) + + '_' + tf_layers_util.object_list_uid(masks)) + self._output_tensor_cache[cache_key] = output_tensors + self._output_mask_cache[cache_key] = output_masks + + if output_shapes is not None: + input_shapes = [tf_layers_util.static_shape(x) for x in inputs] + cache_key = tf_layers_util.object_list_uid(input_shapes) + self._output_shape_cache[cache_key] = output_shapes + + return output_tensors, output_masks + + def get_config(self): + if not self._is_graph_network: + raise NotImplementedError + + config = { + 'name': self.name, + } + node_conversion_map = {} + for layer in self.layers: + if issubclass(layer.__class__, Network): + # Networks start with a pre-existing node + # linking their input to output. + kept_nodes = 1 + else: + kept_nodes = 0 + for original_node_index, node in enumerate(layer._inbound_nodes): + node_key = _make_node_key(layer.name, original_node_index) + if node_key in self._network_nodes: + node_conversion_map[node_key] = kept_nodes + kept_nodes += 1 + layer_configs = [] + for layer in self.layers: # From the earliest layers on. + layer_class_name = layer.__class__.__name__ + layer_config = layer.get_config() + filtered_inbound_nodes = [] + for original_node_index, node in enumerate(layer._inbound_nodes): + node_key = _make_node_key(layer.name, original_node_index) + if node_key in self._network_nodes: + # The node is relevant to the model: + # add to filtered_inbound_nodes. + if node.arguments: + try: + json.dumps(node.arguments) + kwargs = node.arguments + except TypeError: + logging.warning( + 'Layer ' + layer.name + + ' was passed non-serializable keyword arguments: ' + + str(node.arguments) + '. They will not be included ' + 'in the serialized model (and thus will be missing ' + 'at deserialization time).') + kwargs = {} + else: + kwargs = {} + if node.inbound_layers: + node_data = [] + for i in range(len(node.inbound_layers)): + inbound_layer = node.inbound_layers[i] + node_index = node.node_indices[i] + tensor_index = node.tensor_indices[i] + node_key = _make_node_key(inbound_layer.name, node_index) + new_node_index = node_conversion_map.get(node_key, 0) + node_data.append( + [inbound_layer.name, new_node_index, tensor_index, kwargs]) + filtered_inbound_nodes.append(node_data) + layer_configs.append({ + 'name': layer.name, + 'class_name': layer_class_name, + 'config': layer_config, + 'inbound_nodes': filtered_inbound_nodes, + }) + config['layers'] = layer_configs + + # Gather info about inputs and outputs. + model_inputs = [] + for i in range(len(self._input_layers)): + layer, node_index, tensor_index = self._input_coordinates[i] + node_key = _make_node_key(layer.name, node_index) + if node_key not in self._network_nodes: + continue + new_node_index = node_conversion_map[node_key] + model_inputs.append([layer.name, new_node_index, tensor_index]) + config['input_layers'] = model_inputs + model_outputs = [] + for i in range(len(self._output_layers)): + layer, node_index, tensor_index = self._output_coordinates[i] + node_key = _make_node_key(layer.name, node_index) + if node_key not in self._network_nodes: + continue + new_node_index = node_conversion_map[node_key] + model_outputs.append([layer.name, new_node_index, tensor_index]) + config['output_layers'] = model_outputs + return copy.deepcopy(config) + + @classmethod + def from_config(cls, config, custom_objects=None): + """Instantiates a Model from its config (output of `get_config()`). + + Arguments: + config: Model config dictionary. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + + Returns: + A model instance. + + Raises: + ValueError: In case of improperly formatted config dict. + """ + # Layer instances created during + # the graph reconstruction process + created_layers = {} + + # Dictionary mapping layer instances to + # node data that specifies a layer call. + # It acts as a queue that maintains any unprocessed + # layer call until it becomes possible to process it + # (i.e. until the input tensors to the call all exist). + unprocessed_nodes = {} + + def add_unprocessed_node(layer, node_data): + if layer not in unprocessed_nodes: + unprocessed_nodes[layer] = [node_data] + else: + unprocessed_nodes[layer].append(node_data) + + def process_node(layer, node_data): + """Deserialize a node. + + Arguments: + layer: layer instance. + node_data: node config dict. + + Raises: + ValueError: In case of improperly formatted `node_data` dict. + """ + input_tensors = [] + for input_data in node_data: + inbound_layer_name = input_data[0] + inbound_node_index = input_data[1] + inbound_tensor_index = input_data[2] + if len(input_data) == 3: + kwargs = {} + elif len(input_data) == 4: + kwargs = input_data[3] + else: + raise ValueError('Improperly formatted model config.') + if inbound_layer_name not in created_layers: + add_unprocessed_node(layer, node_data) + return + inbound_layer = created_layers[inbound_layer_name] + if len(inbound_layer._inbound_nodes) <= inbound_node_index: + add_unprocessed_node(layer, node_data) + return + inbound_node = inbound_layer._inbound_nodes[inbound_node_index] + input_tensors.append(inbound_node.output_tensors[inbound_tensor_index]) + # Call layer on its inputs, thus creating the node + # and building the layer if needed. + if input_tensors: + if len(input_tensors) == 1: + layer(input_tensors[0], **kwargs) + else: + layer(input_tensors, **kwargs) + + def process_layer(layer_data): + """Deserialize a layer, then call it on appropriate inputs. + + Arguments: + layer_data: layer config dict. + + Raises: + ValueError: In case of improperly formatted `layer_data` dict. + """ + layer_name = layer_data['name'] + + # Instantiate layer. + from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top + + layer = deserialize_layer(layer_data, custom_objects=custom_objects) + created_layers[layer_name] = layer + + # Gather layer inputs. + inbound_nodes_data = layer_data['inbound_nodes'] + for node_data in inbound_nodes_data: + # We don't process nodes (i.e. make layer calls) + # on the fly because the inbound node may not yet exist, + # in case of layer shared at different topological depths + # (e.g. a model such as A(B(A(B(x))))) + add_unprocessed_node(layer, node_data) + + # First, we create all layers and enqueue nodes to be processed + for layer_data in config['layers']: + process_layer(layer_data) + # Then we process nodes in order of layer depth. + # Nodes that cannot yet be processed (if the inbound node + # does not yet exist) are re-enqueued, and the process + # is repeated until all nodes are processed. + while unprocessed_nodes: + for layer_data in config['layers']: + layer = created_layers[layer_data['name']] + if layer in unprocessed_nodes: + for node_data in unprocessed_nodes.pop(layer): + process_node(layer, node_data) + + name = config.get('name') + input_tensors = [] + output_tensors = [] + for layer_data in config['input_layers']: + layer_name, node_index, tensor_index = layer_data + assert layer_name in created_layers + layer = created_layers[layer_name] + layer_output_tensors = layer._inbound_nodes[node_index].output_tensors + input_tensors.append(layer_output_tensors[tensor_index]) + for layer_data in config['output_layers']: + layer_name, node_index, tensor_index = layer_data + assert layer_name in created_layers + layer = created_layers[layer_name] + layer_output_tensors = layer._inbound_nodes[node_index].output_tensors + output_tensors.append(layer_output_tensors[tensor_index]) + return cls(inputs=input_tensors, outputs=output_tensors, name=name) + + def save(self, filepath, overwrite=True, include_optimizer=True): + """Save the model to a single HDF5 file. + + The savefile includes: + - The model architecture, allowing to re-instantiate the model. + - The model weights. + - The state of the optimizer, allowing to resume training + exactly where you left off. + + This allows you to save the entirety of the state of a model + in a single file. + + Saved models can be reinstantiated via `keras.models.load_model`. + The model returned by `load_model` + is a compiled model ready to be used (unless the saved model + was never compiled in the first place). + + Arguments: + filepath: String, path to the file to save the weights to. + overwrite: Whether to silently overwrite any existing file at the + target location, or provide the user with a manual prompt. + include_optimizer: If True, save optimizer's state together. + + Example: + + ```python + from keras.models import load_model + + model.save('my_model.h5') # creates a HDF5 file 'my_model.h5' + del model # deletes the existing model + + # returns a compiled model + # identical to the previous one + model = load_model('my_model.h5') + ``` + """ + if not self._is_graph_network: + raise NotImplementedError + + from tensorflow.python.keras._impl.keras.models import save_model # pylint: disable=g-import-not-at-top + save_model(self, filepath, overwrite, include_optimizer) + + def save_weights(self, filepath, overwrite=True): + """Dumps all layer weights to a HDF5 file. + + The weight file has: + - `layer_names` (attribute), a list of strings + (ordered names of model layers). + - For every layer, a `group` named `layer.name` + - For every such layer group, a group attribute `weight_names`, + a list of strings + (ordered names of weights tensor of the layer). + - For every weight in the layer, a dataset + storing the weight value, named after the weight tensor. + + Arguments: + filepath: String, path to the file to save the weights to. + overwrite: Whether to silently overwrite any existing file at the + target location, or provide the user with a manual prompt. + + Raises: + ImportError: If h5py is not available. + """ + if h5py is None: + raise ImportError('`save_weights` requires h5py.') + # If file exists and should not be overwritten: + if not overwrite and os.path.isfile(filepath): + proceed = ask_to_proceed_with_overwrite(filepath) + if not proceed: + return + with h5py.File(filepath, 'w') as f: + saving.save_weights_to_hdf5_group(f, self.layers) + + def load_weights(self, filepath, by_name=False): + """Loads all layer weights from a HDF5 save file. + + If `by_name` is False (default) weights are loaded + based on the network's topology, meaning the architecture + should be the same as when the weights were saved. + Note that layers that don't have weights are not taken + into account in the topological ordering, so adding or + removing layers is fine as long as they don't have weights. + + If `by_name` is True, weights are loaded into layers + only if they share the same name. This is useful + for fine-tuning or transfer-learning models where + some of the layers have changed. + + Arguments: + filepath: String, path to the weights file to load. + by_name: Boolean, whether to load weights by name + or by topological order. + + Raises: + ImportError: If h5py is not available. + """ + if h5py is None: + raise ImportError('`load_weights` requires h5py.') + with h5py.File(filepath, 'r') as f: + if 'layer_names' not in f.attrs and 'model_weights' in f: + f = f['model_weights'] + if by_name: + saving.load_weights_from_hdf5_group_by_name(f, self.layers) + else: + saving.load_weights_from_hdf5_group(f, self.layers) + + def _updated_config(self): + """Util hared between different serialization methods. + + Returns: + Model config with Keras version information added. + """ + from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top + + config = self.get_config() + model_config = { + 'class_name': self.__class__.__name__, + 'config': config, + 'keras_version': keras_version, + 'backend': K.backend() + } + return model_config + + def to_json(self, **kwargs): + """Returns a JSON string containing the network configuration. + + To load a network from a JSON save file, use + `keras.models.model_from_json(json_string, custom_objects={})`. + + Arguments: + **kwargs: Additional keyword arguments + to be passed to `json.dumps()`. + + Returns: + A JSON string. + """ + if not self._is_graph_network: + raise NotImplementedError + + def get_json_type(obj): + # If obj is any numpy type + if type(obj).__module__ == np.__name__: + return obj.item() + + # If obj is a python 'type' + if type(obj).__name__ == type.__name__: + return obj.__name__ + + raise TypeError('Not JSON Serializable:', obj) + + model_config = self._updated_config() + return json.dumps(model_config, default=get_json_type, **kwargs) + + def to_yaml(self, **kwargs): + """Returns a yaml string containing the network configuration. + + To load a network from a yaml save file, use + `keras.models.model_from_yaml(yaml_string, custom_objects={})`. + + `custom_objects` should be a dictionary mapping + the names of custom losses / layers / etc to the corresponding + functions / classes. + + Arguments: + **kwargs: Additional keyword arguments + to be passed to `yaml.dump()`. + + Returns: + A YAML string. + + Raises: + ImportError: if yaml module is not found. + """ + if not self._is_graph_network: + raise NotImplementedError + + if yaml is None: + raise ImportError('Requires yaml module installed.') + return yaml.dump(self._updated_config(), **kwargs) + + def summary(self, line_length=None, positions=None, print_fn=None): + """Prints a string summary of the network. + + Arguments: + line_length: Total length of printed lines + (e.g. set this to adapt the display to different + terminal window sizes). + positions: Relative or absolute positions of log elements + in each line. If not provided, + defaults to `[.33, .55, .67, 1.]`. + print_fn: Print function to use. Defaults to `print`. + It will be called on each line of the summary. + You can set it to a custom function + in order to capture the string summary. + """ + print_layer_summary(self, + line_length=line_length, + positions=positions, + print_fn=print_fn) + + +def get_source_inputs(tensor, layer=None, node_index=None): + """Returns the list of input tensors necessary to compute `tensor`. + + Output will always be a list of tensors + (potentially with 1 element). + + Arguments: + tensor: The tensor to start from. + layer: Origin layer of the tensor. Will be + determined via tensor._keras_history if not provided. + node_index: Origin node index of the tensor. + + Returns: + List of input tensors. + """ + if not hasattr(tensor, '_keras_history'): + return tensor + + if layer is None or node_index: + layer, node_index, _ = tensor._keras_history + if not layer._inbound_nodes: + return [tensor] + else: + node = layer._inbound_nodes[node_index] + if not node.inbound_layers: + # Reached an Input layer, stop recursion. + return node.input_tensors + else: + source_tensors = [] + for i in range(len(node.inbound_layers)): + x = node.input_tensors[i] + layer = node.inbound_layers[i] + node_index = node.node_indices[i] + previous_sources = get_source_inputs(x, layer, node_index) + # Avoid input redundancy. + for x in previous_sources: + if x not in source_tensors: + source_tensors.append(x) + return source_tensors + + +def _make_node_key(layer_name, node_index): + return layer_name + '_ib-' + str(node_index) + + +def _map_graph_network(inputs, outputs): + """Validate a network's topology and gather its layers and nodes. + + Arguments: + inputs: List of input tensors. + outputs: List of outputs tensors. + + Returns: + A tuple `(nodes, nodes_by_depth, layers, layers_by_depth)`. + - nodes: list of Node instances. + - nodes_by_depth: dict mapping ints (depth) to lists of node instances. + - layers: list of Layer instances. + - layers_by_depth: dict mapping ints (depth) to lists of layer instances. + + Raises: + ValueError: In case the network is not valid (e.g. disconnected graph). + """ + # Network_nodes: set of nodes included in the graph of layers + # (not all nodes included in the layers are relevant to the current graph). + network_nodes = set() # ids of all nodes relevant to the Network + nodes_depths = {} # dict {node: depth value} + layers_depths = {} # dict {layer: depth value} + layer_indices = {} # dict {layer: index in traversal} + nodes_in_decreasing_depth = [] + + def build_map(tensor, + finished_nodes, + nodes_in_progress, + layer, + node_index, + tensor_index): + """Builds a map of the graph of layers. + + This recursively updates the map `layer_indices`, + the list `nodes_in_decreasing_depth` and the set `network_nodes`. + + Arguments: + tensor: Some tensor in a graph. + finished_nodes: Set of nodes whose subgraphs have been traversed + completely. Useful to prevent duplicated work. + nodes_in_progress: Set of nodes that are currently active on the + recursion stack. Useful to detect cycles. + layer: Layer from which `tensor` comes from. If not provided, + will be obtained from `tensor._keras_history`. + node_index: Node index from which `tensor` comes from. + tensor_index: Tensor_index from which `tensor` comes from. + + Raises: + ValueError: if a cycle is detected. + """ + node = layer._inbound_nodes[node_index] # pylint: disable=protected-access + + # Prevent cycles. + if node in nodes_in_progress: + raise ValueError('The tensor ' + str(tensor) + ' at layer "' + + layer.name + '" is part of a cycle.') + + # Don't repeat work for shared subgraphs + if node in finished_nodes: + return + + node_key = _make_node_key(layer.name, node_index) + # Update network_nodes. + network_nodes.add(node_key) + + # Store the traversal order for layer sorting. + if layer not in layer_indices: + layer_indices[layer] = len(layer_indices) + + nodes_in_progress.add(node) + + # Propagate to all previous tensors connected to this node. + for i in range(len(node.inbound_layers)): + x = node.input_tensors[i] + layer = node.inbound_layers[i] + node_index = node.node_indices[i] + tensor_index = node.tensor_indices[i] + build_map(x, finished_nodes, nodes_in_progress, layer, + node_index, tensor_index) + + finished_nodes.add(node) + nodes_in_progress.remove(node) + nodes_in_decreasing_depth.append(node) + + finished_nodes = set() + nodes_in_progress = set() + for x in outputs: + layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access + build_map(x, finished_nodes, nodes_in_progress, + layer=layer, + node_index=node_index, + tensor_index=tensor_index) + + for node in reversed(nodes_in_decreasing_depth): + # If the depth is not set, the node has no outbound nodes (depth 0). + depth = nodes_depths.setdefault(node, 0) + + # Update the depth of the corresponding layer + previous_depth = layers_depths.get(node.outbound_layer, 0) + # If we've seen this layer before at a higher depth, + # we should use that depth instead of the node depth. + # This is necessary for shared layers that have inputs at different + # depth levels in the graph. + depth = max(depth, previous_depth) + layers_depths[node.outbound_layer] = depth + nodes_depths[node] = depth + + # Update the depth of inbound nodes. + # The "depth" of a node is the max of the depths + # of all layers it is connected to. + for i in range(len(node.inbound_layers)): + inbound_layer = node.inbound_layers[i] + node_index = node.node_indices[i] + inbound_node = inbound_layer._inbound_nodes[node_index] # pylint: disable=protected-access + previous_depth = nodes_depths.get(inbound_node, 0) + nodes_depths[inbound_node] = max(depth + 1, previous_depth) + + # Build a dict {depth: list of nodes with this depth} + nodes_by_depth = {} + for node, depth in nodes_depths.items(): + if depth not in nodes_by_depth: + nodes_by_depth[depth] = [] + nodes_by_depth[depth].append(node) + + # Build a dict {depth: list of layers with this depth} + layers_by_depth = {} + for layer, depth in layers_depths.items(): + if depth not in layers_by_depth: + layers_by_depth[depth] = [] + layers_by_depth[depth].append(layer) + + # Get sorted list of layer depths. + depth_keys = list(layers_by_depth.keys()) + depth_keys.sort(reverse=True) + + # Set self.layers and self._layers_by_depth. + layers = [] + for depth in depth_keys: + layers_for_depth = layers_by_depth[depth] + # Network.layers needs to have a deterministic order: + # here we order them by traversal order. + layers_for_depth.sort(key=lambda x: layer_indices[x]) + layers.extend(layers_for_depth) + + # Get sorted list of node depths. + depth_keys = list(nodes_by_depth.keys()) + depth_keys.sort(reverse=True) + + # Check that all tensors required are computable. + # computable_tensors: all tensors in the graph + # that can be computed from the inputs provided. + computable_tensors = [] + for x in inputs: + computable_tensors.append(x) + + layers_with_complete_input = [] # To provide a better error msg. + for depth in depth_keys: + for node in nodes_by_depth[depth]: + layer = node.outbound_layer + if layer: + for x in node.input_tensors: + if x not in computable_tensors: + raise ValueError('Graph disconnected: ' + 'cannot obtain value for tensor ' + str(x) + + ' at layer "' + layer.name + '". ' + 'The following previous layers ' + 'were accessed without issue: ' + + str(layers_with_complete_input)) + for x in node.output_tensors: + computable_tensors.append(x) + layers_with_complete_input.append(layer.name) + + # Ensure name unicity, which will be crucial for serialization + # (since serialized nodes refer to layers by their name). + all_names = [layer.name for layer in layers] + for name in all_names: + if all_names.count(name) != 1: + raise ValueError('The name "' + name + '" is used ' + + str(all_names.count(name)) + ' times in the model. ' + 'All layer names should be unique.') + return network_nodes, nodes_by_depth, layers, layers_by_depth diff --git a/tensorflow/python/keras/_impl/keras/engine/saving.py b/tensorflow/python/keras/_impl/keras/engine/saving.py new file mode 100644 index 0000000000000000000000000000000000000000..52522e693511b010d0501651e594d346984c41e3 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/saving.py @@ -0,0 +1,671 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Model saving utilities. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import os + +import numpy as np +from six.moves import zip # pylint: disable=redefined-builtin + +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import optimizers +from tensorflow.python.keras._impl.keras.utils import conv_utils +from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export + +# pylint: disable=g-import-not-at-top +try: + import h5py +except ImportError: + h5py = None + +try: + import yaml +except ImportError: + yaml = None +# pylint: enable=g-import-not-at-top + + +@tf_export('keras.models.save_model') +def save_model(model, filepath, overwrite=True, include_optimizer=True): + """Save a model to a HDF5 file. + + The saved model contains: + - the model's configuration (topology) + - the model's weights + - the model's optimizer's state (if any) + + Thus the saved model can be reinstantiated in + the exact same state, without any of the code + used for model definition or training. + + Arguments: + model: Keras model instance to be saved. + filepath: String, path where to save the model. + overwrite: Whether we should overwrite any existing + model at the target location, or instead + ask the user with a manual prompt. + include_optimizer: If True, save optimizer's state together. + + Raises: + ImportError: if h5py is not available. + """ + + if h5py is None: + raise ImportError('`save_model` requires h5py.') + + def get_json_type(obj): + """Serialize any object to a JSON-serializable structure. + + Arguments: + obj: the object to serialize + + Returns: + JSON-serializable structure representing `obj`. + + Raises: + TypeError: if `obj` cannot be serialized. + """ + # if obj is a serializable Keras class instance + # e.g. optimizer, layer + if hasattr(obj, 'get_config'): + return {'class_name': obj.__class__.__name__, 'config': obj.get_config()} + + # if obj is any numpy type + if type(obj).__module__ == np.__name__: + if isinstance(obj, np.ndarray): + return {'type': type(obj), 'value': obj.tolist()} + else: + return obj.item() + + # misc functions (e.g. loss function) + if callable(obj): + return obj.__name__ + + # if obj is a python 'type' + if type(obj).__name__ == type.__name__: + return obj.__name__ + + raise TypeError('Not JSON Serializable:', obj) + + from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top + + # If file exists and should not be overwritten. + if not overwrite and os.path.isfile(filepath): + proceed = ask_to_proceed_with_overwrite(filepath) + if not proceed: + return + + with h5py.File(filepath, mode='w') as f: + f.attrs['keras_version'] = str(keras_version).encode('utf8') + f.attrs['backend'] = K.backend().encode('utf8') + f.attrs['model_config'] = json.dumps( + { + 'class_name': model.__class__.__name__, + 'config': model.get_config() + }, + default=get_json_type).encode('utf8') + + model_weights_group = f.create_group('model_weights') + model_layers = model.layers + save_weights_to_hdf5_group(model_weights_group, model_layers) + + if include_optimizer and hasattr(model, 'optimizer'): + if isinstance(model.optimizer, optimizers.TFOptimizer): + logging.warning( + 'TensorFlow optimizers do not ' + 'make it possible to access ' + 'optimizer attributes or optimizer state ' + 'after instantiation. ' + 'As a result, we cannot save the optimizer ' + 'as part of the model save file.' + 'You will have to compile your model again after loading it. ' + 'Prefer using a Keras optimizer instead ' + '(see keras.io/optimizers).') + else: + f.attrs['training_config'] = json.dumps( + { + 'optimizer_config': { + 'class_name': model.optimizer.__class__.__name__, + 'config': model.optimizer.get_config() + }, + 'loss': model.loss, + 'metrics': model.metrics, + 'sample_weight_mode': model.sample_weight_mode, + 'loss_weights': model.loss_weights, + }, + default=get_json_type).encode('utf8') + + # Save optimizer weights. + symbolic_weights = getattr(model.optimizer, 'weights') + if symbolic_weights: + optimizer_weights_group = f.create_group('optimizer_weights') + weight_values = K.batch_get_value(symbolic_weights) + weight_names = [] + for w, val in zip(symbolic_weights, weight_values): + name = str(w.name) + weight_names.append(name.encode('utf8')) + optimizer_weights_group.attrs['weight_names'] = weight_names + for name, val in zip(weight_names, weight_values): + param_dset = optimizer_weights_group.create_dataset( + name, val.shape, dtype=val.dtype) + if not val.shape: + # scalar + param_dset[()] = val + else: + param_dset[:] = val + f.flush() + + +@tf_export('keras.models.load_model') +def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin + """Loads a model saved via `save_model`. + + Arguments: + filepath: String, path to the saved model. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + compile: Boolean, whether to compile the model + after loading. + + Returns: + A Keras model instance. If an optimizer was found + as part of the saved model, the model is already + compiled. Otherwise, the model is uncompiled and + a warning will be displayed. When `compile` is set + to False, the compilation is omitted without any + warning. + + Raises: + ImportError: if h5py is not available. + ValueError: In case of an invalid savefile. + """ + if h5py is None: + raise ImportError('`load_model` requires h5py.') + + if not custom_objects: + custom_objects = {} + + def convert_custom_objects(obj): + """Handles custom object lookup. + + Arguments: + obj: object, dict, or list. + + Returns: + The same structure, where occurrences + of a custom object name have been replaced + with the custom object. + """ + if isinstance(obj, list): + deserialized = [] + for value in obj: + deserialized.append(convert_custom_objects(value)) + return deserialized + if isinstance(obj, dict): + deserialized = {} + for key, value in obj.items(): + deserialized[key] = convert_custom_objects(value) + return deserialized + if obj in custom_objects: + return custom_objects[obj] + return obj + + with h5py.File(filepath, mode='r') as f: + # instantiate model + model_config = f.attrs.get('model_config') + if model_config is None: + raise ValueError('No model found in config file.') + model_config = json.loads(model_config.decode('utf-8')) + model = model_from_config(model_config, custom_objects=custom_objects) + + # set weights + load_weights_from_hdf5_group(f['model_weights'], model.layers) + + # Early return if compilation is not required. + if not compile: + return model + + # instantiate optimizer + training_config = f.attrs.get('training_config') + if training_config is None: + logging.warning('No training configuration found in save file: ' + 'the model was *not* compiled. Compile it manually.') + return model + training_config = json.loads(training_config.decode('utf-8')) + optimizer_config = training_config['optimizer_config'] + optimizer = optimizers.deserialize( + optimizer_config, custom_objects=custom_objects) + + # Recover loss functions and metrics. + loss = convert_custom_objects(training_config['loss']) + metrics = convert_custom_objects(training_config['metrics']) + sample_weight_mode = training_config['sample_weight_mode'] + loss_weights = training_config['loss_weights'] + + # Compile model. + model.compile( + optimizer=optimizer, + loss=loss, + metrics=metrics, + loss_weights=loss_weights, + sample_weight_mode=sample_weight_mode) + + # Set optimizer weights. + if 'optimizer_weights' in f: + # Build train function (to get weight updates). + model._make_train_function() + optimizer_weights_group = f['optimizer_weights'] + optimizer_weight_names = [ + n.decode('utf8') + for n in optimizer_weights_group.attrs['weight_names'] + ] + optimizer_weight_values = [ + optimizer_weights_group[n] for n in optimizer_weight_names + ] + try: + model.optimizer.set_weights(optimizer_weight_values) + except ValueError: + logging.warning('Error in loading the saved optimizer ' + 'state. As a result, your model is ' + 'starting with a freshly initialized ' + 'optimizer.') + return model + + +@tf_export('keras.models.model_from_config') +def model_from_config(config, custom_objects=None): + """Instantiates a Keras model from its config. + + Arguments: + config: Configuration dictionary. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + + Returns: + A Keras model instance (uncompiled). + + Raises: + TypeError: if `config` is not a dictionary. + """ + if isinstance(config, list): + raise TypeError('`model_from_config` expects a dictionary, not a list. ' + 'Maybe you meant to use ' + '`Sequential.from_config(config)`?') + from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top + return deserialize(config, custom_objects=custom_objects) + + +@tf_export('keras.models.model_from_yaml') +def model_from_yaml(yaml_string, custom_objects=None): + """Parses a yaml model configuration file and returns a model instance. + + Arguments: + yaml_string: YAML string encoding a model configuration. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + + Returns: + A Keras model instance (uncompiled). + + Raises: + ImportError: if yaml module is not found. + """ + if yaml is None: + raise ImportError('Requires yaml module installed.') + config = yaml.load(yaml_string) + from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top + return deserialize(config, custom_objects=custom_objects) + + +@tf_export('keras.models.model_from_json') +def model_from_json(json_string, custom_objects=None): + """Parses a JSON model configuration file and returns a model instance. + + Arguments: + json_string: JSON string encoding a model configuration. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + + Returns: + A Keras model instance (uncompiled). + """ + config = json.loads(json_string) + from tensorflow.python.keras._impl.keras.layers import deserialize # pylint: disable=g-import-not-at-top + return deserialize(config, custom_objects=custom_objects) + + +def save_weights_to_hdf5_group(f, layers): + from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top + + f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers] + f.attrs['backend'] = K.backend().encode('utf8') + f.attrs['keras_version'] = str(keras_version).encode('utf8') + + for layer in layers: + g = f.create_group(layer.name) + symbolic_weights = layer.weights + weight_values = K.batch_get_value(symbolic_weights) + weight_names = [] + for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)): + if hasattr(w, 'name') and w.name: + name = str(w.name) + else: + name = 'param_' + str(i) + weight_names.append(name.encode('utf8')) + g.attrs['weight_names'] = weight_names + for name, val in zip(weight_names, weight_values): + param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) + if not val.shape: + # scalar + param_dset[()] = val + else: + param_dset[:] = val + + +def preprocess_weights_for_loading(layer, + weights, + original_keras_version=None, + original_backend=None): + """Converts layers weights from Keras 1 format to Keras 2. + + Arguments: + layer: Layer instance. + weights: List of weights values (Numpy arrays). + original_keras_version: Keras version for the weights, as a string. + original_backend: Keras backend the weights were trained with, + as a string. + + Returns: + A list of weights values (Numpy arrays). + """ + if layer.__class__.__name__ == 'Bidirectional': + num_weights_per_layer = len(weights) // 2 + forward_weights = preprocess_weights_for_loading( + layer.forward_layer, weights[:num_weights_per_layer], + original_keras_version, original_backend) + backward_weights = preprocess_weights_for_loading( + layer.backward_layer, weights[num_weights_per_layer:], + original_keras_version, original_backend) + weights = forward_weights + backward_weights + + if original_keras_version == '1': + if layer.__class__.__name__ == 'TimeDistributed': + weights = preprocess_weights_for_loading( + layer.layer, weights, original_keras_version, original_backend) + + if layer.__class__.__name__ == 'Conv1D': + shape = weights[0].shape + # Handle Keras 1.1 format + if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters: + # Legacy shape: + # (filters, input_dim, filter_length, 1) + assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0], + 1) + weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) + weights[0] = weights[0][:, 0, :, :] + + if layer.__class__.__name__ == 'Conv2D': + if layer.data_format == 'channels_first': + # old: (filters, stack_size, kernel_rows, kernel_cols) + # new: (kernel_rows, kernel_cols, stack_size, filters) + weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) + + if layer.__class__.__name__ == 'Conv2DTranspose': + if layer.data_format == 'channels_last': + # old: (kernel_rows, kernel_cols, stack_size, filters) + # new: (kernel_rows, kernel_cols, filters, stack_size) + weights[0] = np.transpose(weights[0], (0, 1, 3, 2)) + if layer.data_format == 'channels_first': + # old: (filters, stack_size, kernel_rows, kernel_cols) + # new: (kernel_rows, kernel_cols, filters, stack_size) + weights[0] = np.transpose(weights[0], (2, 3, 0, 1)) + + if layer.__class__.__name__ == 'Conv3D': + if layer.data_format == 'channels_first': + # old: (filters, stack_size, ...) + # new: (..., stack_size, filters) + weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0)) + + if layer.__class__.__name__ == 'GRU': + if len(weights) == 9: + kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1) + recurrent_kernel = np.concatenate( + [weights[1], weights[4], weights[7]], axis=-1) + bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1) + weights = [kernel, recurrent_kernel, bias] + + if layer.__class__.__name__ == 'LSTM': + if len(weights) == 12: + # old: i, c, f, o + # new: i, f, c, o + kernel = np.concatenate( + [weights[0], weights[6], weights[3], weights[9]], axis=-1) + recurrent_kernel = np.concatenate( + [weights[1], weights[7], weights[4], weights[10]], axis=-1) + bias = np.concatenate( + [weights[2], weights[8], weights[5], weights[11]], axis=-1) + weights = [kernel, recurrent_kernel, bias] + + if layer.__class__.__name__ == 'ConvLSTM2D': + if len(weights) == 12: + kernel = np.concatenate( + [weights[0], weights[6], weights[3], weights[9]], axis=-1) + recurrent_kernel = np.concatenate( + [weights[1], weights[7], weights[4], weights[10]], axis=-1) + bias = np.concatenate( + [weights[2], weights[8], weights[5], weights[11]], axis=-1) + if layer.data_format == 'channels_first': + # old: (filters, stack_size, kernel_rows, kernel_cols) + # new: (kernel_rows, kernel_cols, stack_size, filters) + kernel = np.transpose(kernel, (2, 3, 1, 0)) + recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) + weights = [kernel, recurrent_kernel, bias] + + if layer.__class__.__name__ in ['Model', 'Sequential']: + new_weights = [] + # trainable weights + for sublayer in layer.layers: + num_weights = len(sublayer.trainable_weights) + if num_weights > 0: + new_weights.extend( + preprocess_weights_for_loading( + layer=sublayer, + weights=weights[:num_weights], + original_keras_version=original_keras_version, + original_backend=original_backend)) + weights = weights[num_weights:] + + # non-trainable weights + for sublayer in layer.layers: + num_weights = len([ + l for l in sublayer.weights if l not in sublayer.trainable_weights + ]) + if num_weights > 0: + new_weights.extend( + preprocess_weights_for_loading( + layer=sublayer, + weights=weights[:num_weights], + original_keras_version=original_keras_version, + original_backend=original_backend)) + weights = weights[num_weights:] + weights = new_weights + + conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] + if layer.__class__.__name__ in conv_layers: + if original_backend == 'theano': + weights[0] = conv_utils.convert_kernel(weights[0]) + if layer.__class__.__name__ == 'ConvLSTM2D': + weights[1] = conv_utils.convert_kernel(weights[1]) + if K.int_shape(layer.weights[0]) != weights[0].shape: + weights[0] = np.transpose(weights[0], (3, 2, 0, 1)) + if layer.__class__.__name__ == 'ConvLSTM2D': + weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) + + # Convert the weights of CuDNNLSTM so that they could be loaded into LSTM + if layer.__class__.__name__ == 'LSTM' and len(weights) == 3: + # Determine if loading a CuDNNLSTM layer from the number of bias weights: + # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) + # if there's no bias weight in the file, skip this conversion + units = weights[1].shape[0] + bias = weights[2] + if len(bias) == units * 8: + # reshape the kernels + kernels = np.split(weights[0], 4, axis=1) + kernels = [ + kernel.reshape(-1).reshape(kernel.shape, order='F') + for kernel in kernels + ] + weights[0] = np.concatenate(kernels, axis=1) + + # transpose the recurrent kernels + recurrent_kernels = np.split(weights[1], 4, axis=1) + recurrent_kernels = [kernel.T for kernel in recurrent_kernels] + weights[1] = np.concatenate(recurrent_kernels, axis=1) + + # split the bias into half and merge + weights[2] = bias[:units * 4] + bias[units * 4:] + + return weights + + +def load_weights_from_hdf5_group(f, layers): + """Implements topological (order-based) weight loading. + + Arguments: + f: A pointer to a HDF5 group. + layers: a list of target layers. + + Raises: + ValueError: in case of mismatch between provided layers + and weights file. + """ + if 'keras_version' in f.attrs: + original_keras_version = f.attrs['keras_version'].decode('utf8') + else: + original_keras_version = '1' + if 'backend' in f.attrs: + original_backend = f.attrs['backend'].decode('utf8') + else: + original_backend = None + + filtered_layers = [] + for layer in layers: + weights = layer.weights + if weights: + filtered_layers.append(layer) + + layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] + filtered_layer_names = [] + for name in layer_names: + g = f[name] + weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] + if weight_names: + filtered_layer_names.append(name) + layer_names = filtered_layer_names + if len(layer_names) != len(filtered_layers): + raise ValueError('You are trying to load a weight file ' + 'containing ' + str(len(layer_names)) + + ' layers into a model with ' + str(len(filtered_layers)) + + ' layers.') + + # We batch weight value assignments in a single backend call + # which provides a speedup in TensorFlow. + weight_value_tuples = [] + for k, name in enumerate(layer_names): + g = f[name] + weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] + weight_values = [g[weight_name] for weight_name in weight_names] + layer = filtered_layers[k] + symbolic_weights = layer.weights + weight_values = preprocess_weights_for_loading( + layer, weight_values, original_keras_version, original_backend) + if len(weight_values) != len(symbolic_weights): + raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + + '" in the current model) was found to ' + 'correspond to layer ' + name + ' in the save file. ' + 'However the new layer ' + layer.name + ' expects ' + + str(len(symbolic_weights)) + + ' weights, but the saved weights have ' + + str(len(weight_values)) + ' elements.') + weight_value_tuples += zip(symbolic_weights, weight_values) + K.batch_set_value(weight_value_tuples) + + +def load_weights_from_hdf5_group_by_name(f, layers): + """Implements name-based weight loading. + + (instead of topological weight loading). + + Layers that have no matching name are skipped. + + Arguments: + f: A pointer to a HDF5 group. + layers: a list of target layers. + + Raises: + ValueError: in case of mismatch between provided layers + and weights file. + """ + if 'keras_version' in f.attrs: + original_keras_version = f.attrs['keras_version'].decode('utf8') + else: + original_keras_version = '1' + if 'backend' in f.attrs: + original_backend = f.attrs['backend'].decode('utf8') + else: + original_backend = None + + # New file format. + layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] + + # Reverse index of layer name to list of layers with name. + index = {} + for layer in layers: + if layer.name: + index.setdefault(layer.name, []).append(layer) + + # We batch weight value assignments in a single backend call + # which provides a speedup in TensorFlow. + weight_value_tuples = [] + for k, name in enumerate(layer_names): + g = f[name] + weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] + weight_values = [g[weight_name] for weight_name in weight_names] + + for layer in index.get(name, []): + symbolic_weights = layer.weights + weight_values = preprocess_weights_for_loading( + layer, weight_values, original_keras_version, original_backend) + if len(weight_values) != len(symbolic_weights): + raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + + '") expects ' + str(len(symbolic_weights)) + + ' weight(s), but the saved weights' + ' have ' + + str(len(weight_values)) + ' element(s).') + # Set values. + for i in range(len(weight_values)): + weight_value_tuples.append((symbolic_weights[i], weight_values[i])) + K.batch_set_value(weight_value_tuples) diff --git a/tensorflow/python/keras/_impl/keras/engine/saving_test.py b/tensorflow/python/keras/_impl/keras/engine/saving_test.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb17641b0d26bc227b142d9302dc1da9637c506 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/saving_test.py @@ -0,0 +1,375 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#,============================================================================ +"""Tests for model saving.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import numpy as np + +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test +from tensorflow.python.training import training as training_module + +try: + import h5py # pylint:disable=g-import-not-at-top +except ImportError: + h5py = None + + +class TestWeightSavingAndLoading(test.TestCase): + + def test_weight_loading(self): + with self.test_session(): + a = keras.layers.Input(shape=(2,)) + x = keras.layers.Dense(3)(a) + b = keras.layers.Dense(1)(x) + model = keras.models.Model(a, b) + + x = np.random.random((3, 2)) + ref_y = model.predict(x) + weights = model.get_weights() + model.set_weights(weights) + y = model.predict(x) + self.assertAllClose(ref_y, y) + + with self.assertRaises(ValueError): + model.set_weights(weights[1:]) + with self.assertRaises(ValueError): + model.set_weights(weights[::-1]) + + if h5py is None: + return # Skip rest of test if H5py isn't available. + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + + h5_path = os.path.join(temp_dir, 'test.h5') + model.save_weights(h5_path) + model.load_weights(h5_path) + y = model.predict(x) + self.assertAllClose(ref_y, y) + + model.load_weights(h5_path, by_name=True) + y = model.predict(x) + self.assertAllClose(ref_y, y) + + def test_weight_preprocessing(self): + input_dim = 3 + output_dim = 3 + size = 2 + cases = [ + [ + (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))), + [np.random.random((2, 1)), np.random.random((2, 1))], + (None, 3, 2), + ], + [ + (keras.layers.TimeDistributed(keras.layers.Dense(1))), + [np.random.random((2, 1)), np.random.random((1,))], + (None, 3, 2), + ], + [ + (keras.layers.Conv1D(output_dim, size, use_bias=False)), + [np.random.random((output_dim, input_dim, size, 1))], + (None, 4, input_dim), + ], + [ + (keras.layers.Conv2D(output_dim, size, + use_bias=False, data_format='channels_first')), + [np.random.random((output_dim, input_dim, size, size))], + (None, input_dim, 4, 4), + ], + [ + (keras.layers.Conv2DTranspose(output_dim, size, + use_bias=False, + data_format='channels_first')), + [np.random.random((output_dim, input_dim, size, size))], + (None, input_dim, 4, 4), + ], + [ + (keras.layers.Conv2DTranspose(output_dim, size, + use_bias=False, + data_format='channels_last')), + [np.random.random((size, size, input_dim, output_dim))], + (None, 4, 4, input_dim), + ], + [ + (keras.layers.Conv3D(output_dim, size, + use_bias=False, data_format='channels_first')), + [np.random.random((output_dim, input_dim, size, size, size))], + (None, input_dim, 4, 4, 4), + ], + [ + (keras.layers.GRU(output_dim)), + [np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,)), + np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,)), + np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,))], + (None, 4, input_dim), + ], + [ + (keras.layers.LSTM(output_dim)), + [np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,)), + np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,)), + np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,)), + np.random.random((input_dim, output_dim)), + np.random.random((output_dim, output_dim)), + np.random.random((output_dim,))], + (None, 4, input_dim), + ], + ] + for layer, weights, input_shape in cases: + layer.build(input_shape) + _ = keras.engine.saving.preprocess_weights_for_loading( + layer, weights, original_keras_version='1') + + model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)]) + _ = keras.engine.saving.preprocess_weights_for_loading( + model, model.weights, original_keras_version='1') + + x = keras.Input((2,)) + y = keras.layers.Dense(2)(x) + model = keras.models.Model(x, y) + _ = keras.engine.saving.preprocess_weights_for_loading( + model, model.weights, original_keras_version='1') + + def test_sequential_weight_loading(self): + if h5py is None: + return + + temp_dir = self.get_temp_dir() + self.addCleanup(shutil.rmtree, temp_dir) + h5_path = os.path.join(temp_dir, 'test.h5') + + num_hidden = 5 + input_dim = 3 + batch_size = 5 + num_classes = 2 + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) + model.add(keras.layers.Dense(num_classes)) + + x = np.random.random((batch_size, input_dim)) + ref_y = model.predict(x) + + model.save_weights(h5_path) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) + model.add(keras.layers.Dense(num_classes)) + model.load_weights(h5_path) + y = model.predict(x) + + self.assertAllClose(y, ref_y) + + +class TestWholeModelSaving(test.TestCase): + + def test_sequential_model_saving(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.RepeatVector(3)) + model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy], + sample_weight_mode='temporal') + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + + new_model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + out2 = new_model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + # test that new updates are the same with both models + x = np.random.random((1, 3)) + y = np.random.random((1, 3, 3)) + model.train_on_batch(x, y) + new_model.train_on_batch(x, y) + out = model.predict(x) + out2 = new_model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + def test_sequential_model_saving_2(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + # test with custom optimizer, loss + + class CustomOp(keras.optimizers.RMSprop): + pass + + def custom_loss(y_true, y_pred): + return keras.losses.mse(y_true, y_pred) + + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc']) + + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + + model = keras.models.load_model( + fname, + custom_objects={'CustomOp': CustomOp, + 'custom_loss': custom_loss}) + os.close(fd) + os.remove(fname) + + out2 = model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + def test_functional_model_saving(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + inputs = keras.layers.Input(shape=(3,)) + x = keras.layers.Dense(2)(inputs) + output = keras.layers.Dense(3)(x) + + model = keras.models.Model(inputs, output) + model.compile(loss=keras.losses.MSE, + optimizer=keras.optimizers.RMSprop(lr=0.0001), + metrics=[keras.metrics.categorical_accuracy]) + x = np.random.random((1, 3)) + y = np.random.random((1, 3)) + model.train_on_batch(x, y) + + out = model.predict(x) + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + out2 = model.predict(x) + self.assertAllClose(out, out2, atol=1e-05) + + def test_saving_without_compilation(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + def test_saving_with_tf_optimizer(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', + optimizer=training_module.AdadeltaOptimizer(0.1), + metrics=['acc']) + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + def test_saving_right_after_compilation(self): + if h5py is None: + return # Skip test if models cannot be saved. + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(2, input_shape=(3,))) + model.add(keras.layers.Dense(3)) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + model.model._make_train_function() + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + def test_saving_lambda_numpy_array_arguments(self): + if h5py is None: + return # Skip test if models cannot be saved. + + mean = np.random.random((4, 2, 3)) + std = np.abs(np.random.random((4, 2, 3))) + 1e-5 + inputs = keras.layers.Input(shape=(4, 2, 3)) + output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std, + arguments={'mu': mean, 'std': std})(inputs) + model = keras.models.Model(inputs, output) + model.compile(loss='mse', optimizer='sgd', metrics=['acc']) + + fd, fname = tempfile.mkstemp('.h5') + keras.models.save_model(model, fname) + + model = keras.models.load_model(fname) + os.close(fd) + os.remove(fname) + + self.assertAllClose(mean, model.layers[1].arguments['mu']) + self.assertAllClose(std, model.layers[1].arguments['std']) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential.py b/tensorflow/python/keras/_impl/keras/engine/sequential.py new file mode 100644 index 0000000000000000000000000000000000000000..db5e7754bc22ba360dbf635f1bd80334f58e8509 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/sequential.py @@ -0,0 +1,997 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=protected-access +"""Home of the `Sequential` model. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import os + +from tensorflow.python.framework import ops +from tensorflow.python.keras._impl.keras import backend as K +from tensorflow.python.keras._impl.keras import layers as layer_module +from tensorflow.python.keras._impl.keras.engine import base_layer +from tensorflow.python.keras._impl.keras.engine import network +from tensorflow.python.keras._impl.keras.engine import saving +from tensorflow.python.keras._impl.keras.engine.input_layer import Input +from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer +from tensorflow.python.keras._impl.keras.engine.training import Model +from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util.tf_export import tf_export + +try: + import h5py # pylint: disable=g-import-not-at-top +except ImportError: + h5py = None + + +@tf_export('keras.models.Sequential', 'keras.Sequential') +class Sequential(Model): + """Linear stack of layers. + + Arguments: + layers: list of layers to add to the model. + + # Note + The first layer passed to a Sequential model + should have a defined input shape. What that + means is that it should have received an `input_shape` + or `batch_input_shape` argument, + or for some type of layers (recurrent, Dense...) + an `input_dim` argument. + + Example: + + ```python + model = Sequential() + # first layer must have a defined input shape + model.add(Dense(32, input_dim=500)) + # afterwards, Keras does automatic shape inference + model.add(Dense(32)) + + # also possible (equivalent to the above): + model = Sequential() + model.add(Dense(32, input_shape=(500,))) + model.add(Dense(32)) + + # also possible (equivalent to the above): + model = Sequential() + # here the batch dimension is None, + # which means any batch size will be accepted by the model. + model.add(Dense(32, batch_input_shape=(None, 500))) + model.add(Dense(32)) + ``` + """ + + def __init__(self, layers=None, name=None): + self._is_graph_network = True + self._is_compiled = False + self._layers = [] # Stack of layers. + self.model = None # Internal Model instance. + self.inputs = [] # List of input tensors + self.outputs = [] # List of length 1: the output tensor (unique). + self._trainable = True + self._initial_weights = None + self._input_layers = [] + + # Model attributes. + self._inbound_nodes = [] + self._outbound_nodes = [] + self.built = False + + # Set model name. + if not name: + prefix = 'sequential_' + name = prefix + str(K.get_uid(prefix)) + self._name = name + + # Used by Layer base class. + self._dtype = None + self._activity_regularizer = None + + # The following properties are not actually used by Keras; + # they exist for compatibility with TF's variable scoping mechanism. + self._updates = [] + self._losses = [] + self._scope = None + self._reuse = None + self._base_name = name + self._graph = ops.get_default_graph() + + # Add to the model any layers passed to the constructor. + if layers: + for layer in layers: + self.add(layer) + + def add(self, layer): + """Adds a layer instance on top of the layer stack. + + Arguments: + layer: layer instance. + + Raises: + TypeError: If `layer` is not a layer instance. + ValueError: In case the `layer` argument does not + know its input shape. + ValueError: In case the `layer` argument has + multiple output tensors, or is already connected + somewhere else (forbidden in `Sequential` models). + """ + if not isinstance(layer, (base_layer.Layer, base_layer.TFBaseLayer)): + raise TypeError('The added layer must be ' + 'an instance of class Layer. ' + 'Found: ' + str(layer)) + if not self.outputs: + # First layer in model: check that it is an input layer. + if not isinstance(layer, InputLayer): + # Create an input layer. + # First, we need to infer its expected input shape and dtype. + if isinstance(layer, (Model, Sequential)): + # We were passed a model as first layer. + # This requires a specific way to figure out the + # input shape and dtype. + if not layer.layers: + raise ValueError('Cannot add an empty model ' + 'to a `Sequential` model.') + # In case of nested models: recover the first layer + # of the deepest model to infer input shape and dtype. + first_layer = layer.layers[0] + while isinstance(first_layer, (Model, Sequential)): + first_layer = first_layer.layers[0] + batch_shape = first_layer._batch_input_shape + dtype = first_layer.dtype + else: + # We were passed a regular layer, and it should + # know about its input shape. Otherwise, that's an error. + if not hasattr(layer, '_batch_input_shape'): + raise ValueError('The first layer in a ' + 'Sequential model must ' + 'get an `input_shape` argument.') + batch_shape = layer._batch_input_shape + dtype = layer.dtype + # Instantiate the input layer. + x = Input( + batch_shape=batch_shape, dtype=dtype, name=layer.name + '_input') + # This will build the current layer + # and create the node connecting the current layer + # to the input layer we just created. + layer(x) + + if len(layer._inbound_nodes[-1].output_tensors) != 1: + raise ValueError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.') + + self.outputs = [layer._inbound_nodes[-1].output_tensors[0]] + self.inputs = network.get_source_inputs(self.outputs[0]) + + # We create an input node, which we will keep updated + # as we add more layers + base_layer.Node( + outbound_layer=self, + inbound_layers=[], + node_indices=[], + tensor_indices=[], + input_tensors=self.inputs, + output_tensors=self.outputs) + else: + output_tensor = layer(self.outputs[0]) + if isinstance(output_tensor, list): + raise TypeError('All layers in a Sequential model ' + 'should have a single output tensor. ' + 'For multi-output layers, ' + 'use the functional API.') + self.outputs = [output_tensor] + # update self._inbound_nodes + self._inbound_nodes[0].output_tensors = self.outputs + self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])] + + self._layers.append(layer) + self.built = False + + def pop(self): + """Removes the last layer in the model. + + Raises: + TypeError: if there are no layers in the model. + """ + if not self.layers: + raise TypeError('There are no layers in the model.') + + self.layers.pop() + if not self.layers: + self.outputs = [] + self._inbound_nodes = [] + self._outbound_nodes = [] + else: + self.layers[-1]._outbound_nodes = [] + self.outputs = [self.layers[-1].output] + # update self._inbound_nodes + self._inbound_nodes[0].output_tensors = self.outputs + self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])] + self.built = False + + def get_layer(self, name=None, index=None): + """Retrieve a layer that is part of the model. + + Returns a layer based on either its name (unique) + or its index in the graph. Indices are based on + order of horizontal graph traversal (bottom-up). + + Arguments: + name: string, name of layer. + index: integer, index of layer. + + Returns: + A layer instance. + """ + if not self.built: + self.build() + return self.model.get_layer(name, index) + + def call(self, inputs, **kwargs): + if not self.built: + self.build() + return self.model.call(inputs, **kwargs) + + def build(self, input_shape=None): + if not self.inputs or not self.outputs: + raise TypeError('Sequential model cannot be built: model is empty.' + ' Add some layers first.') + # actually create the model + self.model = Model(self.inputs, self.outputs[0], name=self.name + '_model') + self.model.trainable = self.trainable + + # mirror model attributes + self.supports_masking = self.model.supports_masking + self._output_mask_cache = self.model._output_mask_cache + self._output_tensor_cache = self.model._output_tensor_cache + self._output_shape_cache = self.model._output_shape_cache + self._input_layers = self.model._input_layers + self._output_layers = self.model._output_layers + self._input_coordinates = self.model._input_coordinates + self._output_coordinates = self.model._output_coordinates + self._nodes_by_depth = self.model._nodes_by_depth + self._network_nodes = self.model._network_nodes + self.output_names = self.model.output_names + self.input_names = self.model.input_names + self._feed_input_names = self.model._feed_input_names + self._feed_inputs = self.model._feed_inputs + + # Make sure child model callbacks + # will call the parent Sequential model. + self.model.callback_model = self + + self.built = True + + @property + def uses_learning_phase(self): + if not self.built: + self.build() + return self.model.uses_learning_phase + + def _gather_list_attr(self, attr): + all_attrs = [] + for layer in self.layers: + all_attrs += getattr(layer, attr, []) + return all_attrs + + def _make_train_function(self): + self.model._make_train_function() + + def _make_test_function(self): + self.model._make_test_function() + + def _make_predict_function(self): + self.model._make_predict_function() + + @property + def trainable(self): + return self._trainable + + @trainable.setter + def trainable(self, value): + if self.model: + self.model.trainable = value + self._trainable = value + + @property + def trainable_weights(self): + if not self.trainable: + return [] + return self._gather_list_attr('trainable_weights') + + @property + def non_trainable_weights(self): + weights = self._gather_list_attr('non_trainable_weights') + if not self.trainable: + trainable_weights = self._gather_list_attr('trainable_weights') + return trainable_weights + weights + return weights + + @property + def regularizers(self): + if not self.built: + self.build() + return self.model.regularizers + + def get_weights(self): + """Retrieves the weights of the model. + + Returns: + A flat list of Numpy arrays + (one array per model weight). + """ + if not self.built: + self.build() + return self.model.get_weights() + + def set_weights(self, weights): + """Sets the weights of the model. + + Arguments: + weights: Should be a list + of Numpy arrays with shapes and types matching + the output of `model.get_weights()`. + """ + if not self.built: + self.build() + self.model.set_weights(weights) + + def load_weights(self, filepath, by_name=False): + if h5py is None: + raise ImportError('`load_weights` requires h5py.') + f = h5py.File(filepath, mode='r') + if 'layer_names' not in f.attrs and 'model_weights' in f: + f = f['model_weights'] + layers = self.layers + if by_name: + saving.load_weights_from_hdf5_group_by_name(f, layers) + else: + saving.load_weights_from_hdf5_group(f, layers) + if hasattr(f, 'close'): + f.close() + + def save_weights(self, filepath, overwrite=True): + if h5py is None: + raise ImportError('`save_weights` requires h5py.') + # If file exists and should not be overwritten: + if not overwrite and os.path.isfile(filepath): + proceed = ask_to_proceed_with_overwrite(filepath) + if not proceed: + return + layers = self.layers + f = h5py.File(filepath, 'w') + saving.save_weights_to_hdf5_group(f, layers) + f.flush() + f.close() + + def compile(self, + optimizer, + loss, + metrics=None, + sample_weight_mode=None, + weighted_metrics=None, + target_tensors=None, + **kwargs): + """Configures the model for training. + + Arguments: + optimizer: String (name of optimizer) or optimizer object. + See [optimizers](/optimizers). + loss: String (name of objective function) or objective function. + See [losses](/losses). + If the model has multiple outputs, you can use a different loss + on each output by passing a dictionary or a list of losses. + The loss value that will be minimized by the model + will then be the sum of all individual losses. + metrics: List of metrics to be evaluated by the model + during training and testing. + Typically you will use `metrics=['accuracy']`. + To specify different metrics for different outputs of a + multi-output model, you could also pass a dictionary, + such as `metrics={'output_a': 'accuracy'}`. + sample_weight_mode: If you need to do timestep-wise + sample weighting (2D weights), set this to `"temporal"`. + `None` defaults to sample-wise weights (1D). + If the model has multiple outputs, you can use a different + `sample_weight_mode` on each output by passing a + dictionary or a list of modes. + weighted_metrics: list of metrics to be evaluated and weighted + by `sample_weight` or `class_weight` during training and testing. + target_tensors: By default, Keras will create a placeholder for the + model's target, which will be fed with the target data during + training. If instead you would like to use your own + target tensor (in turn, Keras will not expect external + Numpy data for these targets at training time), you + can specify them via the `target_tensors` argument. + It should be a single tensor + (for a single-output `Sequential` model). + **kwargs: These arguments are passed into `tf.Session.run`. + + Example: + ```python + model = Sequential() + model.add(Dense(32, input_shape=(500,))) + model.add(Dense(10, activation='softmax')) + model.compile(optimizer='rmsprop', + loss='categorical_crossentropy', + metrics=['accuracy']) + ``` + """ + # create the underlying model + self.build() + # call compile method of Model class + self.model.compile( + optimizer, + loss, + metrics=metrics, + sample_weight_mode=sample_weight_mode, + weighted_metrics=weighted_metrics, + target_tensors=target_tensors, + **kwargs) + self.optimizer = self.model.optimizer + self.loss = self.model.loss + self.metrics = self.model.metrics + self.loss_weights = self.model.loss_weights + self.sample_weight_mode = self.model.sample_weight_mode + self.weighted_metrics = self.model.weighted_metrics + self.targets = self.model.targets + self.metrics_tensors = self.model.metrics_tensors + self.metrics_names = self.model.metrics_names + self.sample_weights = self.model.sample_weights + self.total_loss = self.model.total_loss + + def fit(self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose=1, + callbacks=None, + validation_split=0., + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + **kwargs): + """Trains the model for a fixed number of epochs. + + Arguments: + x: Numpy array of training data. + If the input layer in the model is named, you can also pass a + dictionary mapping the input name to a Numpy array. + `x` can be `None` (default) if feeding from + TensorFlow data tensors. + y: Numpy array of target (label) data. + If the output layer in the model is named, you can also pass a + dictionary mapping the output name to a Numpy array. + `y` can be `None` (default) if feeding from + TensorFlow data tensors. + batch_size: Integer or `None`. + Number of samples per gradient update. + If unspecified, it will default to 32. + epochs: Integer. Number of epochs to train the model. + An epoch is an iteration over the entire `x` and `y` + data provided. + Note that in conjunction with `initial_epoch`, + `epochs` is to be understood as "final epoch". + The model is not trained for a number of iterations + given by `epochs`, but merely until the epoch + of index `epochs` is reached. + verbose: 0, 1, or 2. Verbosity mode. + 0 = silent, 1 = progress bar, 2 = one line per epoch. + callbacks: List of `keras.callbacks.Callback` instances. + List of callbacks to apply during training. + See [callbacks](/callbacks). + validation_split: Float between 0 and 1: + Fraction of the training data to be used as validation data. + The model will set apart this fraction of the training data, + will not train on it, and will evaluate + the loss and any model metrics + on this data at the end of each epoch. + The validation data is selected from the last samples + in the `x` and `y` data provided, before shuffling. + validation_data: tuple `(x_val, y_val)` or tuple + `(x_val, y_val, val_sample_weights)` on which to evaluate + the loss and any model metrics at the end of each epoch. + The model will not be trained on this data. + This will override `validation_split`. + shuffle: Boolean (whether to shuffle the training data + before each epoch) or str (for 'batch'). + 'batch' is a special option for dealing with the + limitations of HDF5 data; it shuffles in batch-sized chunks. + Has no effect when `steps_per_epoch` is not `None`. + class_weight: Optional dictionary mapping class indices (integers) + to a weight (float) value, used for weighting the loss function + (during training only). + This can be useful to tell the model to + "pay more attention" to samples from + an under-represented class. + sample_weight: Optional Numpy array of weights for + the training samples, used for weighting the loss function + (during training only). You can either pass a flat (1D) + Numpy array with the same length as the input samples + (1:1 mapping between weights and samples), + or in the case of temporal data, + you can pass a 2D array with shape + `(samples, sequence_length)`, + to apply a different weight to every timestep of every sample. + In this case you should make sure to specify + `sample_weight_mode="temporal"` in `compile()`. + initial_epoch: Epoch at which to start training + (useful for resuming a previous training run). + steps_per_epoch: Total number of steps (batches of samples) + before declaring one epoch finished and starting the + next epoch. When training with input tensors such as + TensorFlow data tensors, the default `None` is equal to + the number of unique samples in your dataset divided by + the batch size, or 1 if that cannot be determined. + validation_steps: Only relevant if `steps_per_epoch` + is specified. Total number of steps (batches of samples) + to validate before stopping. + **kwargs: Used for backwards compatibility support. + + Returns: + A `History` object. Its `History.history` attribute is + a record of training loss values and metrics values + at successive epochs, as well as validation loss values + and validation metrics values (if applicable). + + Raises: + RuntimeError: If the model was never compiled. + ValueError: In case of mismatch between the provided input data + and what the model expects. + """ + if not self.built: + raise RuntimeError('The model needs to be compiled before being used.') + return self.model.fit( + x, + y, + batch_size=batch_size, + epochs=epochs, + verbose=verbose, + callbacks=callbacks, + validation_split=validation_split, + validation_data=validation_data, + shuffle=shuffle, + class_weight=class_weight, + sample_weight=sample_weight, + initial_epoch=initial_epoch, + steps_per_epoch=steps_per_epoch, + validation_steps=validation_steps) + + def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None): + """Computes the loss on some input data, batch by batch. + + Arguments: + x: input data, as a Numpy array or list of Numpy arrays + (if the model has multiple inputs). + y: labels, as a Numpy array. + batch_size: integer. Number of samples per gradient update. + verbose: verbosity mode, 0 or 1. + sample_weight: sample weights, as a Numpy array. + + Returns: + Scalar test loss (if the model has no metrics) + or list of scalars (if the model computes other metrics). + The attribute `model.metrics_names` will give you + the display labels for the scalar outputs. + + Raises: + RuntimeError: if the model was never compiled. + """ + if not self.built: + raise RuntimeError('The model needs to be compiled before being used.') + return self.model.evaluate( + x, + y, + batch_size=batch_size, + verbose=verbose, + sample_weight=sample_weight) + + def predict(self, x, batch_size=32, verbose=0): + """Generates output predictions for the input samples. + + The input samples are processed batch by batch. + + Arguments: + x: the input data, as a Numpy array. + batch_size: integer. + verbose: verbosity mode, 0 or 1. + + Returns: + A Numpy array of predictions. + """ + if not self.built: + self.build() + return self.model.predict(x, batch_size=batch_size, verbose=verbose) + + def predict_on_batch(self, x): + """Returns predictions for a single batch of samples. + + Arguments: + x: input data, as a Numpy array or list of Numpy arrays + (if the model has multiple inputs). + + Returns: + A Numpy array of predictions. + """ + if not self.built: + self.build() + return self.model.predict_on_batch(x) + + def train_on_batch(self, x, y, class_weight=None, sample_weight=None): + """Single gradient update over one batch of samples. + + Arguments: + x: input data, as a Numpy array or list of Numpy arrays + (if the model has multiple inputs). + y: labels, as a Numpy array. + class_weight: dictionary mapping classes to a weight value, + used for scaling the loss function (during training only). + sample_weight: sample weights, as a Numpy array. + + Returns: + Scalar training loss (if the model has no metrics) + or list of scalars (if the model computes other metrics). + The attribute `model.metrics_names` will give you + the display labels for the scalar outputs. + + Raises: + RuntimeError: if the model was never compiled. + """ + if not self.built: + raise RuntimeError('The model needs to be compiled before being used.') + return self.model.train_on_batch( + x, y, sample_weight=sample_weight, class_weight=class_weight) + + def test_on_batch(self, x, y, sample_weight=None): + """Evaluates the model over a single batch of samples. + + Arguments: + x: input data, as a Numpy array or list of Numpy arrays + (if the model has multiple inputs). + y: labels, as a Numpy array. + sample_weight: sample weights, as a Numpy array. + + Returns: + Scalar test loss (if the model has no metrics) + or list of scalars (if the model computes other metrics). + The attribute `model.metrics_names` will give you + the display labels for the scalar outputs. + + Raises: + RuntimeError: if the model was never compiled. + """ + if not self.built: + raise RuntimeError('The model needs to be compiled before being used.') + return self.model.test_on_batch(x, y, sample_weight=sample_weight) + + def predict_proba(self, x, batch_size=32, verbose=0): + """Generates class probability predictions for the input samples. + + The input samples are processed batch by batch. + + Arguments: + x: input data, as a Numpy array or list of Numpy arrays + (if the model has multiple inputs). + batch_size: integer. + verbose: verbosity mode, 0 or 1. + + Returns: + A Numpy array of probability predictions. + """ + preds = self.predict(x, batch_size, verbose) + if preds.min() < 0. or preds.max() > 1.: + logging.warning('Network returning invalid probability values. ' + 'The last layer might not normalize predictions ' + 'into probabilities ' + '(like softmax or sigmoid would).') + return preds + + def predict_classes(self, x, batch_size=32, verbose=0): + """Generate class predictions for the input samples. + + The input samples are processed batch by batch. + + Arguments: + x: input data, as a Numpy array or list of Numpy arrays + (if the model has multiple inputs). + batch_size: integer. + verbose: verbosity mode, 0 or 1. + + Returns: + A numpy array of class predictions. + """ + proba = self.predict(x, batch_size=batch_size, verbose=verbose) + if proba.shape[-1] > 1: + return proba.argmax(axis=-1) + else: + return (proba > 0.5).astype('int32') + + def fit_generator(self, + generator, + steps_per_epoch=None, + epochs=1, + verbose=1, + callbacks=None, + validation_data=None, + validation_steps=None, + class_weight=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + shuffle=True, + initial_epoch=0, + **kwargs): + """Fits the model on data generated batch-by-batch by a Python generator. + + The generator is run in parallel to the model, for efficiency. + For instance, this allows you to do real-time data augmentation + on images on CPU in parallel to training your model on GPU. + + Arguments: + generator: A generator. + The output of the generator must be either + - a tuple (inputs, targets) + - a tuple (inputs, targets, sample_weights). + All arrays should contain the same number of samples. + The generator is expected to loop over its data + indefinitely. An epoch finishes when `steps_per_epoch` + batches have been seen by the model. + steps_per_epoch: Total number of steps (batches of samples) + to yield from `generator` before declaring one epoch + finished and starting the next epoch. It should typically + be equal to the number of samples of your dataset + divided by the batch size. + Optional for `Sequence`: if unspecified, will use + the `len(generator)` as a number of steps. + epochs: Integer, total number of iterations on the data. + Note that in conjunction with initial_epoch, the parameter + epochs is to be understood as "final epoch". The model is + not trained for n steps given by epochs, but until the + epoch epochs is reached. + verbose: Verbosity mode, 0, 1, or 2. + callbacks: List of callbacks to be called during training. + validation_data: This can be either + - A generator for the validation data + - A tuple (inputs, targets) + - A tuple (inputs, targets, sample_weights). + validation_steps: Only relevant if `validation_data` + is a generator. + Number of steps to yield from validation generator + at the end of every epoch. It should typically + be equal to the number of samples of your + validation dataset divided by the batch size. + Optional for `Sequence`: if unspecified, will use + the `len(validation_data)` as a number of steps. + class_weight: Dictionary mapping class indices to a weight + for the class. + max_queue_size: Maximum size for the generator queue + workers: Maximum number of processes to spin up + use_multiprocessing: If True, use process based threading. + Note that because + this implementation relies on multiprocessing, + you should not pass + non picklable arguments to the generator + as they can't be passed + easily to children processes. + shuffle: Whether to shuffle the order of the batches at + the beginning of each epoch. Only used with instances + of `Sequence` (keras.utils.Sequence). + initial_epoch: Epoch at which to start training + (useful for resuming a previous training run) + **kwargs: support for legacy arguments. + + Returns: + A `History` object. + + Raises: + RuntimeError: if the model was never compiled. + ValueError: In case the generator yields + data in an invalid format. + + Example: + + ```python + def generate_arrays_from_file(path): + while 1: + f = open(path) + for line in f: + # create Numpy arrays of input data + # and labels, from each line in the file + x, y = process_line(line) + yield (x, y) + f.close() + + model.fit_generator(generate_arrays_from_file('/my_file.txt'), + steps_per_epoch=1000, epochs=10) + ``` + """ + # Legacy support + if 'max_q_size' in kwargs: + max_queue_size = kwargs.pop('max_q_size') + logging.warning('The argument `max_q_size` has been renamed ' + '`max_queue_size`. Update your method calls accordingly.') + if 'pickle_safe' in kwargs: + use_multiprocessing = kwargs.pop('pickle_safe') + logging.warning('The argument `pickle_safe` has been renamed ' + '`use_multiprocessing`. ' + 'Update your method calls accordingly.') + if kwargs: + raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) + + if not self.built: + raise RuntimeError('The model needs to be compiled before being used.') + return self.model.fit_generator( + generator, + steps_per_epoch, + epochs, + verbose=verbose, + callbacks=callbacks, + validation_data=validation_data, + validation_steps=validation_steps, + class_weight=class_weight, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + shuffle=shuffle, + initial_epoch=initial_epoch) + + def evaluate_generator(self, + generator, + steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + **kwargs): + """Evaluates the model on a data generator. + + The generator should return the same kind of data + as accepted by `test_on_batch`. + + Arguments: + generator: Generator yielding tuples (inputs, targets) + or (inputs, targets, sample_weights) + steps: Total number of steps (batches of samples) + to yield from `generator` before stopping. + Optional for `Sequence`: if unspecified, will use + the `len(generator)` as a number of steps. + max_queue_size: maximum size for the generator queue + workers: maximum number of processes to spin up + use_multiprocessing: if True, use process based threading. + Note that because this implementation + relies on multiprocessing, you should not pass + non picklable arguments to the generator + as they can't be passed easily to children processes. + **kwargs: support for legacy arguments. + + Returns: + Scalar test loss (if the model has no metrics) + or list of scalars (if the model computes other metrics). + The attribute `model.metrics_names` will give you + the display labels for the scalar outputs. + + Raises: + RuntimeError: if the model was never compiled. + ValueError: In case the generator yields + data in an invalid format. + """ + # Legacy support + if 'max_q_size' in kwargs: + max_queue_size = kwargs.pop('max_q_size') + logging.warning('The argument `max_q_size` has been renamed ' + '`max_queue_size`. Update your method calls accordingly.') + if 'pickle_safe' in kwargs: + use_multiprocessing = kwargs.pop('pickle_safe') + logging.warning('The argument `pickle_safe` has been renamed ' + '`use_multiprocessing`. ' + 'Update your method calls accordingly.') + if kwargs: + raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) + + if not self.built: + raise RuntimeError('The model needs to be compiled before being used.') + return self.model.evaluate_generator( + generator, + steps, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing) + + def predict_generator(self, + generator, + steps=None, + max_queue_size=10, + workers=1, + use_multiprocessing=False, + verbose=0, + **kwargs): + """Generates predictions for the input samples from a data generator. + + The generator should return the same kind of data as accepted by + `predict_on_batch`. + + Arguments: + generator: generator yielding batches of input samples. + steps: Total number of steps (batches of samples) + to yield from `generator` before stopping. + Optional for `Sequence`: if unspecified, will use + the `len(generator)` as a number of steps. + max_queue_size: maximum size for the generator queue + workers: maximum number of processes to spin up + use_multiprocessing: if True, use process based threading. + Note that because this implementation + relies on multiprocessing, you should not pass + non picklable arguments to the generator + as they can't be passed easily to children processes. + verbose: verbosity mode, 0 or 1. + **kwargs: support for legacy arguments. + + Returns: + A Numpy array of predictions. + + Raises: + ValueError: In case the generator yields + data in an invalid format. + """ + # Legacy support + if 'max_q_size' in kwargs: + max_queue_size = kwargs.pop('max_q_size') + logging.warning('The argument `max_q_size` has been renamed ' + '`max_queue_size`. Update your method calls accordingly.') + if 'pickle_safe' in kwargs: + use_multiprocessing = kwargs.pop('pickle_safe') + logging.warning('The argument `pickle_safe` has been renamed ' + '`use_multiprocessing`. ' + 'Update your method calls accordingly.') + if kwargs: + raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) + + if not self.built: + self.build() + return self.model.predict_generator( + generator, + steps, + max_queue_size=max_queue_size, + workers=workers, + use_multiprocessing=use_multiprocessing, + verbose=verbose) + + def get_config(self): + config = [] + for layer in self.layers: + config.append({ + 'class_name': layer.__class__.__name__, + 'config': layer.get_config() + }) + return copy.deepcopy(config) + + @classmethod + def from_config(cls, config, custom_objects=None): + model = cls() + for conf in config: + layer = layer_module.deserialize(conf, custom_objects=custom_objects) + model.add(layer) + return model diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential_test.py b/tensorflow/python/keras/_impl/keras/engine/sequential_test.py new file mode 100644 index 0000000000000000000000000000000000000000..166634bd8219b831ce212ba983a4ab695b00c3b7 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/engine/sequential_test.py @@ -0,0 +1,152 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests specific to `Sequential` model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.keras._impl import keras +from tensorflow.python.platform import test + + +class TestSequential(test.TestCase): + """Most Sequential model API tests are covered in `training_test.py`. + """ + + def test_basic_methods(self): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_dim=2)) + model.add(keras.layers.Dropout(0.3, name='dp')) + model.add(keras.layers.Dense(2, kernel_regularizer='l2', + kernel_constraint='max_norm')) + model.build() + self.assertEqual(model.state_updates, model.model.state_updates) + self.assertEqual(model.get_layer(name='dp').name, 'dp') + + def test_sequential_pop(self): + num_hidden = 5 + input_dim = 3 + batch_size = 5 + num_classes = 2 + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) + model.add(keras.layers.Dense(num_classes)) + model.compile(loss='mse', optimizer='sgd') + x = np.random.random((batch_size, input_dim)) + y = np.random.random((batch_size, num_classes)) + model.fit(x, y, epochs=1) + model.pop() + self.assertEqual(len(model.layers), 1) + self.assertEqual(model.output_shape, (None, num_hidden)) + model.compile(loss='mse', optimizer='sgd') + y = np.random.random((batch_size, num_hidden)) + model.fit(x, y, epochs=1) + + # Test popping single-layer model + model = keras.models.Sequential() + model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) + model.pop() + self.assertEqual(len(model.layers), 0) + self.assertEqual(len(model.outputs), 0) + + # Invalid use case + model = keras.models.Sequential() + with self.assertRaises(TypeError): + model.pop() + + def test_invalid_use_cases(self): + with self.test_session(): + # Added objects must be layer instances + with self.assertRaises(TypeError): + model = keras.models.Sequential() + model.add(None) + + # Added layers must have an inputs shape + with self.assertRaises(ValueError): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1)) + + # Added layers cannot have multiple outputs + class MyLayer(keras.layers.Layer): + + def call(self, inputs): + return [3 * inputs, 2 * inputs] + + def compute_output_shape(self, input_shape): + return [input_shape, input_shape] + + with self.assertRaises(ValueError): + model = keras.models.Sequential() + model.add(MyLayer(input_shape=(3,))) + with self.assertRaises(TypeError): + model = keras.models.Sequential() + model.add(keras.layers.Dense(1, input_dim=1)) + model.add(MyLayer()) + + # Building empty model + model = keras.models.Sequential() + with self.assertRaises(TypeError): + model.build() + + def test_nested_sequential_trainability(self): + input_dim = 20 + num_units = 10 + num_classes = 2 + + inner_model = keras.models.Sequential() + inner_model.add(keras.layers.Dense(num_units, input_shape=(input_dim,))) + + model = keras.models.Sequential() + model.add(inner_model) + model.add(keras.layers.Dense(num_classes)) + + self.assertEqual(len(model.trainable_weights), 4) + inner_model.trainable = False + self.assertEqual(len(model.trainable_weights), 2) + inner_model.trainable = True + self.assertEqual(len(model.trainable_weights), 4) + + def test_sequential_update_disabling(self): + val_a = np.random.random((10, 4)) + val_out = np.random.random((10, 4)) + + with self.test_session(): + model = keras.models.Sequential() + model.add(keras.layers.BatchNormalization(input_shape=(4,))) + + model.trainable = False + assert not model.updates + + model.compile('sgd', 'mse') + assert not model.updates + assert not model.model.updates + + x1 = model.predict(val_a) + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + self.assertAllClose(x1, x2, atol=1e-7) + + model.trainable = True + model.compile('sgd', 'mse') + assert model.updates + assert model.model.updates + + model.train_on_batch(val_a, val_out) + x2 = model.predict(val_a) + assert np.abs(np.sum(x1 - x2)) > 1e-5 diff --git a/tensorflow/python/keras/_impl/keras/engine/topology.py b/tensorflow/python/keras/_impl/keras/engine/topology.py deleted file mode 100644 index d1c1d2c8c41ba5c3192827fdf0407b6da01e82bd..0000000000000000000000000000000000000000 --- a/tensorflow/python/keras/_impl/keras/engine/topology.py +++ /dev/null @@ -1,1607 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -# pylint: disable=protected-access -"""Base layer code and base model (Network) code. -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import copy -import json -import os - -import numpy as np -from six.moves import zip # pylint: disable=redefined-builtin - -from tensorflow.python.eager import context -from tensorflow.python.framework import tensor_shape -from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import constraints -from tensorflow.python.keras._impl.keras import initializers -from tensorflow.python.keras._impl.keras import regularizers -from tensorflow.python.keras._impl.keras.utils import conv_utils -from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite -from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary -from tensorflow.python.layers import base as tf_base_layers -from tensorflow.python.layers import network as tf_network -from tensorflow.python.layers import utils as tf_layers_util -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util.tf_export import tf_export - - -# pylint: disable=g-import-not-at-top -try: - import h5py -except ImportError: - h5py = None - -try: - import yaml -except ImportError: - yaml = None -# pylint: enable=g-import-not-at-top - -# pylint: disable=invalid-name -InputSpec = tf_base_layers.InputSpec -Node = tf_base_layers.Node -TFBaseLayer = tf_base_layers.Layer -# pylint: enable=invalid-name - - -@tf_export('keras.layers.Layer') -class Layer(tf_base_layers.Layer): - """Abstract base layer class. - - # Properties - name: String, must be unique within a model. - input_spec: List of InputSpec class instances - each entry describes one required input: - - ndim - - dtype - A layer with `n` input tensors must have - an `input_spec` of length `n`. - trainable: Boolean, whether the layer weights - will be updated during training. - uses_learning_phase: Whether any operation - of the layer uses `K.in_training_phase()` - or `K.in_test_phase()`. - input_shape: Shape tuple. Provided for convenience, - but note that there may be cases in which this - attribute is ill-defined (e.g. a shared layer - with multiple input shapes), in which case - requesting `input_shape` will raise an Exception. - Prefer using `layer.get_input_shape_for(input_shape)`, - or `layer.get_input_shape_at(node_index)`. - output_shape: Shape tuple. See above. - inbound_nodes: List of nodes. - outbound_nodes: List of nodes. - input, output: Input/output tensor(s). Note that if the layer is used - more than once (shared layer), this is ill-defined - and will raise an exception. In such cases, use - `layer.get_input_at(node_index)`. - input_mask, output_mask: Same as above, for masks. - trainable_weights: List of variables. - non_trainable_weights: List of variables. - weights: The concatenation of the lists trainable_weights and - non_trainable_weights (in this order). - - # Methods - call(x, mask=None): Where the layer's logic lives. - __call__(x, mask=None): Wrapper around the layer logic (`call`). - If x is a Keras tensor: - - Connect current layer with last layer from tensor: - `self._add_inbound_node(last_layer)` - - Add layer to tensor history - If layer is not built: - - Build from inputs shape - get_weights() - set_weights(weights) - get_config() - count_params() - compute_output_shape(input_shape) - compute_mask(x, mask) - get_input_at(node_index) - get_output_at(node_index) - get_input_shape_at(node_index) - get_output_shape_at(node_index) - get_input_mask_at(node_index) - get_output_mask_at(node_index) - - # Class Methods - from_config(config) - - # Internal methods: - build(input_shape) - _add_inbound_node(layer, index=0) - """ - - def __init__(self, **kwargs): - # These properties should be set by the user via keyword arguments. - # note that 'dtype', 'input_shape' and 'batch_input_shape' - # are only applicable to input layers: do not pass these keywords - # to non-input layers. - allowed_kwargs = { - 'activity_regularizer', - 'input_shape', - 'batch_input_shape', - 'batch_size', - 'dtype', - 'name', - 'trainable', - 'weights', - } - # Validate optional keyword arguments. - for kwarg in kwargs: - if kwarg not in allowed_kwargs: - raise TypeError('Keyword argument not understood:', kwarg) - - # Get layer name. - name = kwargs.get('name') - - # Get `trainable` status. - trainable = kwargs.get('trainable', True) - - # Get `dtype`. - dtype = kwargs.get('dtype') - if dtype is None: - dtype = K.floatx() - - # Call super, which will set all properties common to Keras layers - # and core TF layers. - super(Layer, self).__init__( - name=name, dtype=dtype, trainable=trainable, - activity_regularizer=kwargs.get('activity_regularizer')) - - # Add properties that are Keras-only for now. - self.supports_masking = False - - # Manage input shape information if passed. - if 'input_shape' in kwargs or 'batch_input_shape' in kwargs: - # In this case we will later create an input layer - # to insert before the current layer - if 'batch_input_shape' in kwargs: - batch_input_shape = tuple(kwargs['batch_input_shape']) - elif 'input_shape' in kwargs: - if 'batch_size' in kwargs: - batch_size = kwargs['batch_size'] - else: - batch_size = None - batch_input_shape = (batch_size,) + tuple(kwargs['input_shape']) - self._batch_input_shape = batch_input_shape - - # Manage initial weight values if passed. - if 'weights' in kwargs: - self._initial_weights = kwargs['weights'] - else: - self._initial_weights = None - - def add_weight(self, - name, - shape, - dtype=None, - initializer=None, - regularizer=None, - trainable=True, - constraint=None): - """Adds a weight variable to the layer. - - Arguments: - name: String, the name for the weight variable. - shape: The shape tuple of the weight. - dtype: The dtype of the weight. - initializer: An Initializer instance (callable). - regularizer: An optional Regularizer instance. - trainable: A boolean, whether the weight should - be trained via backprop or not (assuming - that the layer itself is also trainable). - constraint: An optional Constraint instance. - - Returns: - The created weight variable. - """ - if dtype is None: - dtype = K.floatx() - weight = self.add_variable(name, shape, - dtype=dtype, - initializer=initializers.get(initializer), - regularizer=regularizers.get(regularizer), - constraint=constraints.get(constraint), - trainable=trainable) - return weight - - def call(self, inputs, **kwargs): # pylint: disable=unused-argument - """This is where the layer's logic lives. - - Arguments: - inputs: Input tensor, or list/tuple of input tensors. - **kwargs: Additional keyword arguments. - - Returns: - A tensor or list/tuple of tensors. - """ - return inputs - - def __call__(self, inputs, **kwargs): - """Wrapper around self.call(), for handling internal references. - - If a Keras tensor is passed: - - We call self._add_inbound_node(). - - If necessary, we `build` the layer to match - the shape of the input(s). - - We update the _keras_history of the output tensor(s) - with the current layer. - This is done as part of _add_inbound_node(). - - Arguments: - inputs: Can be a tensor or list/tuple of tensors. - **kwargs: Additional keyword arguments to be passed to `call()`. - - Returns: - Output of the layer's `call` method. - - Raises: - ValueError: in case the layer is missing shape information - for its `build` call. - """ - # Actually call the layer (optionally building it). - output = super(Layer, self).__call__(inputs, **kwargs) - if context.in_eager_mode(): - return output - - # Update learning phase info. - output_tensors = _to_list(output) - uses_lp = any( - [getattr(x, '_uses_learning_phase', False) for x in _to_list(inputs)]) - uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp - for i in range(len(output_tensors)): - output_tensors[i]._uses_learning_phase = getattr( - output_tensors[i], '_uses_learning_phase', False) or uses_lp - - # Optionally load weight values that were specified at layer instantiation. - if hasattr(self, '_initial_weights') and self._initial_weights is not None: - self.set_weights(self._initial_weights) - del self._initial_weights - return output - - def compute_output_shape(self, input_shape): - """Computes the output shape of the layer. - - Assumes that the layer will be built - to match that input shape provided. - - Arguments: - input_shape: Shape tuple (tuple of integers) - or list of shape tuples (one per output tensor of the layer). - Shape tuples can include None for free dimensions, - instead of an integer. - - Returns: - An input shape tuple. - """ - logging.warning( - 'All custom layers should implement the ' - '`compute_output_shape` method. This layer (' + self.name + ') ' - 'is relying on the base `Layer.compute_output_shape` implementation, ' - 'which will start raising a `NotImplementedError` ' - 'as of July 1st, 2018.') - return input_shape - - def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument - """Computes an output mask tensor. - - Arguments: - inputs: Tensor or list of tensors. - mask: Tensor or list of tensors. - - Returns: - None or a tensor (or list of tensors, - one per output tensor of the layer). - """ - if not self.supports_masking: - if mask is not None: - if isinstance(mask, list): - if any(m is not None for m in mask): - raise TypeError('Layer ' + self.name + ' does not support masking, ' - 'but was passed an input_mask: ' + str(mask)) - else: - raise TypeError('Layer ' + self.name + ' does not support masking, ' - 'but was passed an input_mask: ' + str(mask)) - # masking not explicitly supported: return None as mask - return None - # if masking is explicitly supported, by default - # carry over the input mask - return mask - - def get_input_mask_at(self, node_index): - """Retrieves the input mask tensor(s) of a layer at a given node. - - Arguments: - node_index: Integer, index of the node - from which to retrieve the attribute. - E.g. `node_index=0` will correspond to the - first time the layer was called. - - Returns: - A mask tensor - (or list of tensors if the layer has multiple inputs). - """ - inputs = self.get_input_at(node_index) - if isinstance(inputs, list): - return [getattr(x, '_keras_mask', None) for x in inputs] - else: - return getattr(inputs, '_keras_mask', None) - - def get_output_mask_at(self, node_index): - """Retrieves the output mask tensor(s) of a layer at a given node. - - Arguments: - node_index: Integer, index of the node - from which to retrieve the attribute. - E.g. `node_index=0` will correspond to the - first time the layer was called. - - Returns: - A mask tensor - (or list of tensors if the layer has multiple outputs). - """ - output = self.get_output_at(node_index) - if isinstance(output, list): - return [getattr(x, '_keras_mask', None) for x in output] - else: - return getattr(output, '_keras_mask', None) - - @property - def input_mask(self): - """Retrieves the input mask tensor(s) of a layer. - - Only applicable if the layer has exactly one inbound node, - i.e. if it is connected to one incoming layer. - - Returns: - Input mask tensor (potentially None) or list of input - mask tensors. - - Raises: - AttributeError: if the layer is connected to - more than one incoming layers. - """ - inputs = self.input - if isinstance(inputs, list): - return [getattr(x, '_keras_mask', None) for x in inputs] - else: - return getattr(inputs, '_keras_mask', None) - - @property - def output_mask(self): - """Retrieves the output mask tensor(s) of a layer. - - Only applicable if the layer has exactly one inbound node, - i.e. if it is connected to one incoming layer. - - Returns: - Output mask tensor (potentially None) or list of output - mask tensors. - - Raises: - AttributeError: if the layer is connected to - more than one incoming layers. - """ - output = self.output - if isinstance(output, list): - return [getattr(x, '_keras_mask', None) for x in output] - else: - return getattr(output, '_keras_mask', None) - - def set_weights(self, weights): - """Sets the weights of the layer, from Numpy arrays. - - Arguments: - weights: a list of Numpy arrays. The number - of arrays and their shape must match - number of the dimensions of the weights - of the layer (i.e. it should match the - output of `get_weights`). - - Raises: - ValueError: If the provided weights list does not match the - layer's specifications. - """ - params = self.weights - if len(params) != len(weights): - raise ValueError('You called `set_weights(weights)` on layer "' + - self.name + '" with a weight list of length ' + - str(len(weights)) + ', but the layer was expecting ' + - str(len(params)) + ' weights. Provided weights: ' + - str(weights)[:50] + '...') - if not params: - return - weight_value_tuples = [] - param_values = K.batch_get_value(params) - for pv, p, w in zip(param_values, params, weights): - if pv.shape != w.shape: - raise ValueError('Layer weight shape ' + str(pv.shape) + - ' not compatible with ' - 'provided weight shape ' + str(w.shape)) - weight_value_tuples.append((p, w)) - K.batch_set_value(weight_value_tuples) - - def get_weights(self): - """Returns the current weights of the layer. - - Returns: - Weights values as a list of numpy arrays. - """ - params = self.weights - return K.batch_get_value(params) - - def get_config(self): - """Returns the config of the layer. - - A layer config is a Python dictionary (serializable) - containing the configuration of a layer. - The same layer can be reinstantiated later - (without its trained weights) from this configuration. - - The config of a layer does not include connectivity - information, nor the layer class name. These are handled - by `Network` (one layer of abstraction above). - - Returns: - Python dictionary. - """ - config = {'name': self.name, 'trainable': self.trainable} - if hasattr(self, '_batch_input_shape'): - config['batch_input_shape'] = self._batch_input_shape - if hasattr(self, 'dtype'): - config['dtype'] = self.dtype - return config - - @classmethod - def from_config(cls, config): - """Creates a layer from its config. - - This method is the reverse of `get_config`, - capable of instantiating the same layer from the config - dictionary. It does not handle layer connectivity - (handled by Network), nor weights (handled by `set_weights`). - - Arguments: - config: A Python dictionary, typically the - output of get_config. - - Returns: - A layer instance. - """ - return cls(**config) - - @tf_base_layers.Layer.activity_regularizer.setter - def activity_regularizer(self, activity_regularizer): - self._activity_regularizer = activity_regularizer - - -@tf_export('keras.layers.InputLayer') -class InputLayer(tf_network.InputLayer, Layer): - """Layer to be used as an entry point into a graph. - - It can either wrap an existing tensor (pass an `input_tensor` argument) - or create its a placeholder tensor (pass argument `input_shape`. - - Arguments: - input_shape: Shape tuple, not including the batch axis. - batch_size: Optional input batch size (integer or None). - dtype: Datatype of the input. - input_tensor: Optional tensor to use as layer input - instead of creating a placeholder. - sparse: Boolean, whether the placeholder created - is meant to be sparse. - name: Name of the layer (string). - """ - - def __init__(self, - input_shape=None, - batch_size=None, - dtype=None, - input_tensor=None, - sparse=False, - name=None, - **kwargs): - if 'batch_input_shape' in kwargs: - batch_input_shape = kwargs.pop('batch_input_shape') - if input_shape and batch_input_shape: - raise ValueError('Only provide the input_shape OR ' - 'batch_input_shape argument to ' - 'InputLayer, not both at the same time.') - batch_size = batch_input_shape[0] - input_shape = batch_input_shape[1:] - if kwargs: - raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) - - if not name: - prefix = 'input' - name = prefix + '_' + str(K.get_uid(prefix)) - - if not dtype: - if input_tensor is None: - dtype = K.floatx() - else: - dtype = K.dtype(input_tensor) - super(InputLayer, self).__init__(input_shape=input_shape, - batch_size=batch_size, - dtype=dtype, - input_tensor=input_tensor, - sparse=sparse, - name=name) - - def get_config(self): - config = { - 'batch_input_shape': self._batch_input_shape, - 'dtype': self.dtype, - 'sparse': self.sparse, - 'name': self.name - } - return config - - -@tf_export('keras.layers.Input', 'keras.Input') -def Input( # pylint: disable=invalid-name - shape=None, - batch_size=None, - name=None, - dtype=None, - sparse=False, - tensor=None, - **kwargs): - """`Input()` is used to instantiate a Keras tensor. - - A Keras tensor is a tensor object from the underlying backend - (Theano or TensorFlow), which we augment with certain - attributes that allow us to build a Keras model - just by knowing the inputs and outputs of the model. - - For instance, if a, b and c are Keras tensors, - it becomes possible to do: - `model = Model(input=[a, b], output=c)` - - The added Keras attribute is: - `_keras_history`: Last layer applied to the tensor. - the entire layer graph is retrievable from that layer, - recursively. - - Arguments: - shape: A shape tuple (integers), not including the batch size. - For instance, `shape=(32,)` indicates that the expected input - will be batches of 32-dimensional vectors. - batch_size: optional static batch size (integer). - name: An optional name string for the layer. - Should be unique in a model (do not reuse the same name twice). - It will be autogenerated if it isn't provided. - dtype: The data type expected by the input, as a string - (`float32`, `float64`, `int32`...) - sparse: A boolean specifying whether the placeholder - to be created is sparse. - tensor: Optional existing tensor to wrap into the `Input` layer. - If set, the layer will not create a placeholder tensor. - **kwargs: deprecated arguments support. - - Returns: - A tensor. - - Example: - - ```python - # this is a logistic regression in Keras - x = Input(shape=(32,)) - y = Dense(16, activation='softmax')(x) - model = Model(x, y) - ``` - - Raises: - ValueError: in case of invalid arguments. - """ - if 'batch_shape' in kwargs: - batch_shape = kwargs.pop('batch_shape') - if shape and batch_shape: - raise ValueError('Only provide the shape OR ' - 'batch_shape argument to ' - 'Input, not both at the same time.') - batch_size = batch_shape[0] - shape = batch_shape[1:] - if kwargs: - raise ValueError('Unrecognized keyword arguments:', kwargs.keys()) - - if dtype is None: - dtype = K.floatx() - if not shape and tensor is None: - raise ValueError('Please provide to Input either a `shape`' - ' or a `tensor` argument. Note that ' - '`shape` does not include the batch ' - 'dimension.') - input_layer = InputLayer( - input_shape=shape, - batch_size=batch_size, - name=name, - dtype=dtype, - sparse=sparse, - input_tensor=tensor) - # Return tensor including `_keras_history`. - # Note that in this case train_output and test_output are the same pointer. - outputs = input_layer._inbound_nodes[0].output_tensors - if len(outputs) == 1: - return outputs[0] - else: - return outputs - - -class Network(tf_network.GraphNetwork, Layer): - """A Network is a directed acyclic graph of layers. - - It is the topological form of a "model". A Model - is simply a Network with added training routines. - - # Properties - name - inputs - outputs - input_layers - output_layers - input_spec (list of class instances) - each entry describes one required input: - - ndim - - dtype - trainable (boolean) - input_shape - output_shape - inbound_nodes: list of nodes - outbound_nodes: list of nodes - trainable_weights (list of variables) - non_trainable_weights (list of variables) - - # Methods - summary - get_layer - get_weights - set_weights - get_config - compute_output_shape - - # Class Methods - from_config - """ - - def __init__(self, inputs, outputs, name=None): - super(Network, self).__init__(inputs, outputs, name=name) - - self.supports_masking = False - # Fill in the output mask cache. - masks = [] - for x in self.inputs: - mask = x._keras_mask if hasattr(x, '_keras_mask') else None - masks.append(mask) - mask_cache_key = (tf_layers_util.object_list_uid(self.inputs) + '_' + - tf_layers_util.object_list_uid(masks)) - masks = [] - for x in self.outputs: - mask = x._keras_mask if hasattr(x, '_keras_mask') else None - masks.append(mask) - if len(masks) == 1: - mask = masks[0] - else: - mask = masks - self._output_mask_cache[mask_cache_key] = mask - - # Build self.input_names and self.output_names. - self.input_names = [] - self.output_names = [] - self._feed_input_names = [] - self._feed_inputs = [] - self._feed_input_shapes = [] - for i, layer in enumerate(self._input_layers): - self.input_names.append(layer.name) - if layer.is_placeholder: - self._feed_input_names.append(layer.name) - self._feed_input_shapes.append(K.int_shape(self.inputs[i])) - # layer.input gives an error in eager mode - if context.in_graph_mode(): - self._feed_inputs.append(layer.input) - for layer in self._output_layers: - self.output_names.append(layer.name) - - self._internal_input_shapes = [K.int_shape(x) for x in self.inputs] - self._internal_output_shapes = [K.int_shape(x) for x in self.outputs] - - @property - def uses_learning_phase(self): - return any( - [getattr(x, '_uses_learning_phase', False) for x in self.outputs]) - - @property - def stateful(self): - return any([(hasattr(layer, 'stateful') and layer.stateful) - for layer in self.layers]) - - def reset_states(self): - for layer in self.layers: - if hasattr(layer, 'reset_states') and getattr(layer, 'stateful', False): - layer.reset_states() - - @property - def state_updates(self): - """Returns the `updates` from all layers that are stateful. - - This is useful for separating training updates and - state updates, e.g. when we need to update a layer's internal state - during prediction. - - Returns: - A list of update ops. - """ - state_updates = [] - for layer in self.layers: - if getattr(layer, 'stateful', False): - if hasattr(layer, 'updates'): - state_updates += layer.updates - return state_updates - - def get_weights(self): - """Retrieves the weights of the model. - - Returns: - A flat list of Numpy arrays. - """ - weights = [] - for layer in self.layers: - weights += layer.weights - return K.batch_get_value(weights) - - def set_weights(self, weights): - """Sets the weights of the model. - - Arguments: - weights: A list of Numpy arrays with shapes and types matching - the output of `model.get_weights()`. - """ - tuples = [] - for layer in self.layers: - num_param = len(layer.weights) - layer_weights = weights[:num_param] - for sw, w in zip(layer.weights, layer_weights): - tuples.append((sw, w)) - weights = weights[num_param:] - K.batch_set_value(tuples) - - def compute_mask(self, inputs, mask): - inputs = _to_list(inputs) - if mask is None: - masks = [None for _ in range(len(inputs))] - else: - masks = _to_list(mask) - cache_key = (tf_layers_util.object_list_uid(inputs) - + '_' + tf_layers_util.object_list_uid(masks)) - if cache_key in self._output_mask_cache: - return self._output_mask_cache[cache_key] - else: - _, output_masks = self._run_internal_graph(inputs, masks) - return output_masks - - def get_config(self): - config = { - 'name': self.name, - } - node_conversion_map = {} - for layer in self.layers: - if issubclass(layer.__class__, Network): - # Networks start with a pre-existing node - # linking their input to output. - kept_nodes = 1 - else: - kept_nodes = 0 - for original_node_index, node in enumerate(layer._inbound_nodes): - node_key = tf_network._make_node_key(layer.name, - original_node_index) - if node_key in self._network_nodes: - node_conversion_map[node_key] = kept_nodes - kept_nodes += 1 - layer_configs = [] - for layer in self.layers: # From the earliest layers on. - layer_class_name = layer.__class__.__name__ - layer_config = layer.get_config() - filtered_inbound_nodes = [] - for original_node_index, node in enumerate(layer._inbound_nodes): - node_key = tf_network._make_node_key(layer.name, - original_node_index) - if node_key in self._network_nodes: - # The node is relevant to the model: - # add to filtered_inbound_nodes. - if node.arguments: - try: - json.dumps(node.arguments) - kwargs = node.arguments - except TypeError: - logging.warning( - 'Layer ' + layer.name + - ' was passed non-serializable keyword arguments: ' + - str(node.arguments) + '. They will not be included ' - 'in the serialized model (and thus will be missing ' - 'at deserialization time).') - kwargs = {} - else: - kwargs = {} - if node.inbound_layers: - node_data = [] - for i in range(len(node.inbound_layers)): - inbound_layer = node.inbound_layers[i] - node_index = node.node_indices[i] - tensor_index = node.tensor_indices[i] - node_key = tf_network._make_node_key(inbound_layer.name, - node_index) - new_node_index = node_conversion_map.get(node_key, 0) - node_data.append( - [inbound_layer.name, new_node_index, tensor_index, kwargs]) - filtered_inbound_nodes.append(node_data) - layer_configs.append({ - 'name': layer.name, - 'class_name': layer_class_name, - 'config': layer_config, - 'inbound_nodes': filtered_inbound_nodes, - }) - config['layers'] = layer_configs - - # Gather info about inputs and outputs. - model_inputs = [] - for i in range(len(self._input_layers)): - layer, node_index, tensor_index = self._input_coordinates[i] - node_key = tf_network._make_node_key(layer.name, - node_index) - if node_key not in self._network_nodes: - continue - new_node_index = node_conversion_map[node_key] - model_inputs.append([layer.name, new_node_index, tensor_index]) - config['input_layers'] = model_inputs - model_outputs = [] - for i in range(len(self._output_layers)): - layer, node_index, tensor_index = self._output_coordinates[i] - node_key = tf_network._make_node_key(layer.name, - node_index) - if node_key not in self._network_nodes: - continue - new_node_index = node_conversion_map[node_key] - model_outputs.append([layer.name, new_node_index, tensor_index]) - config['output_layers'] = model_outputs - return copy.deepcopy(config) - - @classmethod - def from_config(cls, config, custom_objects=None): - """Instantiates a Model from its config (output of `get_config()`). - - Arguments: - config: Model config dictionary. - custom_objects: Optional dictionary mapping names - (strings) to custom classes or functions to be - considered during deserialization. - - Returns: - A model instance. - - Raises: - ValueError: In case of improperly formatted config dict. - """ - # Layer instances created during - # the graph reconstruction process - created_layers = {} - - # Dictionary mapping layer instances to - # node data that specifies a layer call. - # It acts as a queue that maintains any unprocessed - # layer call until it becomes possible to process it - # (i.e. until the input tensors to the call all exist). - unprocessed_nodes = {} - - def add_unprocessed_node(layer, node_data): - if layer not in unprocessed_nodes: - unprocessed_nodes[layer] = [node_data] - else: - unprocessed_nodes[layer].append(node_data) - - def process_node(layer, node_data): - """Deserialize a node. - - Arguments: - layer: layer instance. - node_data: node config dict. - - Raises: - ValueError: In case of improperly formatted `node_data` dict. - """ - input_tensors = [] - for input_data in node_data: - inbound_layer_name = input_data[0] - inbound_node_index = input_data[1] - inbound_tensor_index = input_data[2] - if len(input_data) == 3: - kwargs = {} - elif len(input_data) == 4: - kwargs = input_data[3] - else: - raise ValueError('Improperly formatted model config.') - if inbound_layer_name not in created_layers: - add_unprocessed_node(layer, node_data) - return - inbound_layer = created_layers[inbound_layer_name] - if len(inbound_layer._inbound_nodes) <= inbound_node_index: - add_unprocessed_node(layer, node_data) - return - inbound_node = inbound_layer._inbound_nodes[inbound_node_index] - input_tensors.append(inbound_node.output_tensors[inbound_tensor_index]) - # Call layer on its inputs, thus creating the node - # and building the layer if needed. - if input_tensors: - if len(input_tensors) == 1: - layer(input_tensors[0], **kwargs) - else: - layer(input_tensors, **kwargs) - - def process_layer(layer_data): - """Deserialize a layer, then call it on appropriate inputs. - - Arguments: - layer_data: layer config dict. - - Raises: - ValueError: In case of improperly formatted `layer_data` dict. - """ - layer_name = layer_data['name'] - - # Instantiate layer. - from tensorflow.python.keras._impl.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top - - layer = deserialize_layer(layer_data, custom_objects=custom_objects) - created_layers[layer_name] = layer - - # Gather layer inputs. - inbound_nodes_data = layer_data['inbound_nodes'] - for node_data in inbound_nodes_data: - # We don't process nodes (i.e. make layer calls) - # on the fly because the inbound node may not yet exist, - # in case of layer shared at different topological depths - # (e.g. a model such as A(B(A(B(x))))) - add_unprocessed_node(layer, node_data) - - # First, we create all layers and enqueue nodes to be processed - for layer_data in config['layers']: - process_layer(layer_data) - # Then we process nodes in order of layer depth. - # Nodes that cannot yet be processed (if the inbound node - # does not yet exist) are re-enqueued, and the process - # is repeated until all nodes are processed. - while unprocessed_nodes: - for layer_data in config['layers']: - layer = created_layers[layer_data['name']] - if layer in unprocessed_nodes: - for node_data in unprocessed_nodes.pop(layer): - process_node(layer, node_data) - - name = config.get('name') - input_tensors = [] - output_tensors = [] - for layer_data in config['input_layers']: - layer_name, node_index, tensor_index = layer_data - assert layer_name in created_layers - layer = created_layers[layer_name] - layer_output_tensors = layer._inbound_nodes[node_index].output_tensors - input_tensors.append(layer_output_tensors[tensor_index]) - for layer_data in config['output_layers']: - layer_name, node_index, tensor_index = layer_data - assert layer_name in created_layers - layer = created_layers[layer_name] - layer_output_tensors = layer._inbound_nodes[node_index].output_tensors - output_tensors.append(layer_output_tensors[tensor_index]) - return cls(inputs=input_tensors, outputs=output_tensors, name=name) - - def save(self, filepath, overwrite=True, include_optimizer=True): - """Save the model to a single HDF5 file. - - The savefile includes: - - The model architecture, allowing to re-instantiate the model. - - The model weights. - - The state of the optimizer, allowing to resume training - exactly where you left off. - - This allows you to save the entirety of the state of a model - in a single file. - - Saved models can be reinstantiated via `keras.models.load_model`. - The model returned by `load_model` - is a compiled model ready to be used (unless the saved model - was never compiled in the first place). - - Arguments: - filepath: String, path to the file to save the weights to. - overwrite: Whether to silently overwrite any existing file at the - target location, or provide the user with a manual prompt. - include_optimizer: If True, save optimizer's state together. - - Example: - - ```python - from keras.models import load_model - - model.save('my_model.h5') # creates a HDF5 file 'my_model.h5' - del model # deletes the existing model - - # returns a compiled model - # identical to the previous one - model = load_model('my_model.h5') - ``` - """ - from tensorflow.python.keras._impl.keras.models import save_model # pylint: disable=g-import-not-at-top - save_model(self, filepath, overwrite, include_optimizer) - - def save_weights(self, filepath, overwrite=True): - """Dumps all layer weights to a HDF5 file. - - The weight file has: - - `layer_names` (attribute), a list of strings - (ordered names of model layers). - - For every layer, a `group` named `layer.name` - - For every such layer group, a group attribute `weight_names`, - a list of strings - (ordered names of weights tensor of the layer). - - For every weight in the layer, a dataset - storing the weight value, named after the weight tensor. - - Arguments: - filepath: String, path to the file to save the weights to. - overwrite: Whether to silently overwrite any existing file at the - target location, or provide the user with a manual prompt. - - Raises: - ImportError: If h5py is not available. - """ - if h5py is None: - raise ImportError('`save_weights` requires h5py.') - # If file exists and should not be overwritten: - if not overwrite and os.path.isfile(filepath): - proceed = ask_to_proceed_with_overwrite(filepath) - if not proceed: - return - f = h5py.File(filepath, 'w') - save_weights_to_hdf5_group(f, self.layers) - f.flush() - f.close() - - def load_weights(self, filepath, by_name=False): - """Loads all layer weights from a HDF5 save file. - - If `by_name` is False (default) weights are loaded - based on the network's topology, meaning the architecture - should be the same as when the weights were saved. - Note that layers that don't have weights are not taken - into account in the topological ordering, so adding or - removing layers is fine as long as they don't have weights. - - If `by_name` is True, weights are loaded into layers - only if they share the same name. This is useful - for fine-tuning or transfer-learning models where - some of the layers have changed. - - Arguments: - filepath: String, path to the weights file to load. - by_name: Boolean, whether to load weights by name - or by topological order. - - Raises: - ImportError: If h5py is not available. - """ - if h5py is None: - raise ImportError('`load_weights` requires h5py.') - f = h5py.File(filepath, mode='r') - if 'layer_names' not in f.attrs and 'model_weights' in f: - f = f['model_weights'] - if by_name: - load_weights_from_hdf5_group_by_name(f, self.layers) - else: - load_weights_from_hdf5_group(f, self.layers) - - if hasattr(f, 'close'): - f.close() - - def _updated_config(self): - """Util hared between different serialization methods. - - Returns: - Model config with Keras version information added. - """ - from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top - - config = self.get_config() - model_config = { - 'class_name': self.__class__.__name__, - 'config': config, - 'keras_version': keras_version, - 'backend': K.backend() - } - return model_config - - def to_json(self, **kwargs): - """Returns a JSON string containing the network configuration. - - To load a network from a JSON save file, use - `keras.models.model_from_json(json_string, custom_objects={})`. - - Arguments: - **kwargs: Additional keyword arguments - to be passed to `json.dumps()`. - - Returns: - A JSON string. - """ - - def get_json_type(obj): - # If obj is any numpy type - if type(obj).__module__ == np.__name__: - return obj.item() - - # If obj is a python 'type' - if type(obj).__name__ == type.__name__: - return obj.__name__ - - raise TypeError('Not JSON Serializable:', obj) - - model_config = self._updated_config() - return json.dumps(model_config, default=get_json_type, **kwargs) - - def to_yaml(self, **kwargs): - """Returns a yaml string containing the network configuration. - - To load a network from a yaml save file, use - `keras.models.model_from_yaml(yaml_string, custom_objects={})`. - - `custom_objects` should be a dictionary mapping - the names of custom losses / layers / etc to the corresponding - functions / classes. - - Arguments: - **kwargs: Additional keyword arguments - to be passed to `yaml.dump()`. - - Returns: - A YAML string. - - Raises: - ImportError: if yaml module is not found. - """ - if yaml is None: - raise ImportError('Requires yaml module installed.') - return yaml.dump(self._updated_config(), **kwargs) - - def summary(self, line_length=None, positions=None, print_fn=None): - """Prints a string summary of the network. - - Arguments: - line_length: Total length of printed lines - (e.g. set this to adapt the display to different - terminal window sizes). - positions: Relative or absolute positions of log elements - in each line. If not provided, - defaults to `[.33, .55, .67, 1.]`. - print_fn: Print function to use. Defaults to `print`. - It will be called on each line of the summary. - You can set it to a custom function - in order to capture the string summary. - """ - print_layer_summary(self, - line_length=line_length, - positions=positions, - print_fn=print_fn) - - -def get_source_inputs(tensor, layer=None, node_index=None): - """Returns the list of input tensors necessary to compute `tensor`. - - Output will always be a list of tensors - (potentially with 1 element). - - Arguments: - tensor: The tensor to start from. - layer: Origin layer of the tensor. Will be - determined via tensor._keras_history if not provided. - node_index: Origin node index of the tensor. - - Returns: - List of input tensors. - """ - if not hasattr(tensor, '_keras_history'): - return tensor - - if layer is None or node_index: - layer, node_index, _ = tensor._keras_history - if not layer._inbound_nodes: - return [tensor] - else: - node = layer._inbound_nodes[node_index] - if not node.inbound_layers: - # Reached an Input layer, stop recursion. - return node.input_tensors - else: - source_tensors = [] - for i in range(len(node.inbound_layers)): - x = node.input_tensors[i] - layer = node.inbound_layers[i] - node_index = node.node_indices[i] - previous_sources = get_source_inputs(x, layer, node_index) - # Avoid input redundancy. - for x in previous_sources: - if x not in source_tensors: - source_tensors.append(x) - return source_tensors - - -def _to_list(x): - """Normalizes a list/tensor into a list. - - If a tensor is passed, we return - a list of size 1 containing the tensor. - - Arguments: - x: target object to be normalized. - - Returns: - A list. - """ - if isinstance(x, list): - return x - return [x] - - -def save_weights_to_hdf5_group(f, layers): - from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top - - f.attrs['layer_names'] = [layer.name.encode('utf8') for layer in layers] - f.attrs['backend'] = K.backend().encode('utf8') - f.attrs['keras_version'] = str(keras_version).encode('utf8') - - for layer in layers: - g = f.create_group(layer.name) - symbolic_weights = layer.weights - weight_values = K.batch_get_value(symbolic_weights) - weight_names = [] - for i, (w, val) in enumerate(zip(symbolic_weights, weight_values)): - if hasattr(w, 'name') and w.name: - name = str(w.name) - else: - name = 'param_' + str(i) - weight_names.append(name.encode('utf8')) - g.attrs['weight_names'] = weight_names - for name, val in zip(weight_names, weight_values): - param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) - if not val.shape: - # scalar - param_dset[()] = val - else: - param_dset[:] = val - - -def preprocess_weights_for_loading(layer, - weights, - original_keras_version=None, - original_backend=None): - """Converts layers weights from Keras 1 format to Keras 2. - - Arguments: - layer: Layer instance. - weights: List of weights values (Numpy arrays). - original_keras_version: Keras version for the weights, as a string. - original_backend: Keras backend the weights were trained with, - as a string. - - Returns: - A list of weights values (Numpy arrays). - """ - if layer.__class__.__name__ == 'Bidirectional': - num_weights_per_layer = len(weights) // 2 - forward_weights = preprocess_weights_for_loading( - layer.forward_layer, weights[:num_weights_per_layer], - original_keras_version, original_backend) - backward_weights = preprocess_weights_for_loading( - layer.backward_layer, weights[num_weights_per_layer:], - original_keras_version, original_backend) - weights = forward_weights + backward_weights - - if original_keras_version == '1': - if layer.__class__.__name__ == 'TimeDistributed': - weights = preprocess_weights_for_loading( - layer.layer, weights, original_keras_version, original_backend) - - if layer.__class__.__name__ == 'Conv1D': - shape = weights[0].shape - # Handle Keras 1.1 format - if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters: - # Legacy shape: - # (filters, input_dim, filter_length, 1) - assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0], - 1) - weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) - weights[0] = weights[0][:, 0, :, :] - - if layer.__class__.__name__ == 'Conv2D': - if layer.data_format == 'channels_first': - # old: (filters, stack_size, kernel_rows, kernel_cols) - # new: (kernel_rows, kernel_cols, stack_size, filters) - weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) - - if layer.__class__.__name__ == 'Conv2DTranspose': - if layer.data_format == 'channels_last': - # old: (kernel_rows, kernel_cols, stack_size, filters) - # new: (kernel_rows, kernel_cols, filters, stack_size) - weights[0] = np.transpose(weights[0], (0, 1, 3, 2)) - if layer.data_format == 'channels_first': - # old: (filters, stack_size, kernel_rows, kernel_cols) - # new: (kernel_rows, kernel_cols, filters, stack_size) - weights[0] = np.transpose(weights[0], (2, 3, 0, 1)) - - if layer.__class__.__name__ == 'Conv3D': - if layer.data_format == 'channels_first': - # old: (filters, stack_size, ...) - # new: (..., stack_size, filters) - weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0)) - - if layer.__class__.__name__ == 'GRU': - if len(weights) == 9: - kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1) - recurrent_kernel = np.concatenate( - [weights[1], weights[4], weights[7]], axis=-1) - bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1) - weights = [kernel, recurrent_kernel, bias] - - if layer.__class__.__name__ == 'LSTM': - if len(weights) == 12: - # old: i, c, f, o - # new: i, f, c, o - kernel = np.concatenate( - [weights[0], weights[6], weights[3], weights[9]], axis=-1) - recurrent_kernel = np.concatenate( - [weights[1], weights[7], weights[4], weights[10]], axis=-1) - bias = np.concatenate( - [weights[2], weights[8], weights[5], weights[11]], axis=-1) - weights = [kernel, recurrent_kernel, bias] - - if layer.__class__.__name__ == 'ConvLSTM2D': - if len(weights) == 12: - kernel = np.concatenate( - [weights[0], weights[6], weights[3], weights[9]], axis=-1) - recurrent_kernel = np.concatenate( - [weights[1], weights[7], weights[4], weights[10]], axis=-1) - bias = np.concatenate( - [weights[2], weights[8], weights[5], weights[11]], axis=-1) - if layer.data_format == 'channels_first': - # old: (filters, stack_size, kernel_rows, kernel_cols) - # new: (kernel_rows, kernel_cols, stack_size, filters) - kernel = np.transpose(kernel, (2, 3, 1, 0)) - recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) - weights = [kernel, recurrent_kernel, bias] - - if layer.__class__.__name__ in ['Model', 'Sequential']: - new_weights = [] - # trainable weights - for sublayer in layer.layers: - num_weights = len(sublayer.trainable_weights) - if num_weights > 0: - new_weights.extend( - preprocess_weights_for_loading( - layer=sublayer, - weights=weights[:num_weights], - original_keras_version=original_keras_version, - original_backend=original_backend)) - weights = weights[num_weights:] - - # non-trainable weights - for sublayer in layer.layers: - num_weights = len([ - l for l in sublayer.weights if l not in sublayer.trainable_weights - ]) - if num_weights > 0: - new_weights.extend( - preprocess_weights_for_loading( - layer=sublayer, - weights=weights[:num_weights], - original_keras_version=original_keras_version, - original_backend=original_backend)) - weights = weights[num_weights:] - weights = new_weights - - conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] - if layer.__class__.__name__ in conv_layers: - if original_backend == 'theano': - weights[0] = conv_utils.convert_kernel(weights[0]) - if layer.__class__.__name__ == 'ConvLSTM2D': - weights[1] = conv_utils.convert_kernel(weights[1]) - if K.int_shape(layer.weights[0]) != weights[0].shape: - weights[0] = np.transpose(weights[0], (3, 2, 0, 1)) - if layer.__class__.__name__ == 'ConvLSTM2D': - weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) - - # Convert the weights of CuDNNLSTM so that they could be loaded into LSTM - if layer.__class__.__name__ == 'LSTM' and len(weights) == 3: - # Determine if loading a CuDNNLSTM layer from the number of bias weights: - # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) - # if there's no bias weight in the file, skip this conversion - units = weights[1].shape[0] - bias = weights[2] - if len(bias) == units * 8: - # reshape the kernels - kernels = np.split(weights[0], 4, axis=1) - kernels = [ - kernel.reshape(-1).reshape(kernel.shape, order='F') - for kernel in kernels - ] - weights[0] = np.concatenate(kernels, axis=1) - - # transpose the recurrent kernels - recurrent_kernels = np.split(weights[1], 4, axis=1) - recurrent_kernels = [kernel.T for kernel in recurrent_kernels] - weights[1] = np.concatenate(recurrent_kernels, axis=1) - - # split the bias into half and merge - weights[2] = bias[:units * 4] + bias[units * 4:] - - return weights - - -def load_weights_from_hdf5_group(f, layers): - """Implements topological (order-based) weight loading. - - Arguments: - f: A pointer to a HDF5 group. - layers: a list of target layers. - - Raises: - ValueError: in case of mismatch between provided layers - and weights file. - """ - if 'keras_version' in f.attrs: - original_keras_version = f.attrs['keras_version'].decode('utf8') - else: - original_keras_version = '1' - if 'backend' in f.attrs: - original_backend = f.attrs['backend'].decode('utf8') - else: - original_backend = None - - filtered_layers = [] - for layer in layers: - weights = layer.weights - if weights: - filtered_layers.append(layer) - - layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] - filtered_layer_names = [] - for name in layer_names: - g = f[name] - weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] - if weight_names: - filtered_layer_names.append(name) - layer_names = filtered_layer_names - if len(layer_names) != len(filtered_layers): - raise ValueError('You are trying to load a weight file ' - 'containing ' + str(len(layer_names)) + - ' layers into a model with ' + str(len(filtered_layers)) + - ' layers.') - - # We batch weight value assignments in a single backend call - # which provides a speedup in TensorFlow. - weight_value_tuples = [] - for k, name in enumerate(layer_names): - g = f[name] - weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] - weight_values = [g[weight_name] for weight_name in weight_names] - layer = filtered_layers[k] - symbolic_weights = layer.weights - weight_values = preprocess_weights_for_loading( - layer, weight_values, original_keras_version, original_backend) - if len(weight_values) != len(symbolic_weights): - raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + - '" in the current model) was found to ' - 'correspond to layer ' + name + ' in the save file. ' - 'However the new layer ' + layer.name + ' expects ' + - str(len(symbolic_weights)) + - ' weights, but the saved weights have ' + - str(len(weight_values)) + ' elements.') - weight_value_tuples += zip(symbolic_weights, weight_values) - K.batch_set_value(weight_value_tuples) - - -def load_weights_from_hdf5_group_by_name(f, layers): - """Implements name-based weight loading. - - (instead of topological weight loading). - - Layers that have no matching name are skipped. - - Arguments: - f: A pointer to a HDF5 group. - layers: a list of target layers. - - Raises: - ValueError: in case of mismatch between provided layers - and weights file. - """ - if 'keras_version' in f.attrs: - original_keras_version = f.attrs['keras_version'].decode('utf8') - else: - original_keras_version = '1' - if 'backend' in f.attrs: - original_backend = f.attrs['backend'].decode('utf8') - else: - original_backend = None - - # New file format. - layer_names = [n.decode('utf8') for n in f.attrs['layer_names']] - - # Reverse index of layer name to list of layers with name. - index = {} - for layer in layers: - if layer.name: - index.setdefault(layer.name, []).append(layer) - - # We batch weight value assignments in a single backend call - # which provides a speedup in TensorFlow. - weight_value_tuples = [] - for k, name in enumerate(layer_names): - g = f[name] - weight_names = [n.decode('utf8') for n in g.attrs['weight_names']] - weight_values = [g[weight_name] for weight_name in weight_names] - - for layer in index.get(name, []): - symbolic_weights = layer.weights - weight_values = preprocess_weights_for_loading( - layer, weight_values, original_keras_version, original_backend) - if len(weight_values) != len(symbolic_weights): - raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + - '") expects ' + str(len(symbolic_weights)) + - ' weight(s), but the saved weights' + ' have ' + - str(len(weight_values)) + ' element(s).') - # Set values. - for i in range(len(weight_values)): - weight_value_tuples.append((symbolic_weights[i], weight_values[i])) - K.batch_set_value(weight_value_tuples) - - -def shape_type_conversion(fn): - """Decorator that handles tuple/TensorShape conversion. - - Used in `compute_output_shape` and `build`. - - Arguments: - fn: function to wrap. - - Returns: - Wrapped function. - """ - - def wrapper(instance, input_shape): - if input_shape is not None: - if isinstance(input_shape, list): - input_shape = [ - tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape] - else: - input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list()) - output_shape = fn(instance, input_shape) - if output_shape is not None: - if isinstance(output_shape, list): - return [tensor_shape.TensorShape(x) for x in output_shape] - return tensor_shape.TensorShape(output_shape) - - return wrapper diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py index 0673e4237674cf01c3df5ab7dc8e13f1de03e477..04434323d6a9f8e12ad8f45c1e83819dfa8b3b96 100644 --- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py @@ -18,13 +18,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import shutil - import numpy as np +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util from tensorflow.python.keras._impl import keras +from tensorflow.python.layers import base as tf_base_layers from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops @@ -35,37 +36,255 @@ try: except ImportError: yaml = None -try: - import h5py # pylint:disable=g-import-not-at-top -except ImportError: - h5py = None - class TopologyConstructionTest(test.TestCase): - def test_get_updates_for(self): - a = keras.layers.Input(shape=(1,)) - dense_layer = keras.layers.Dense(1) - dense_layer.build((None, 1)) - update_1 = state_ops.assign_add(dense_layer.kernel, a) - update_2 = state_ops.assign_add(dense_layer.kernel, [[1.]]) - dense_layer.add_update(update_1, inputs=a) - dense_layer.add_update(update_2, inputs=None) - - self.assertListEqual(dense_layer.get_updates_for(a), [update_1]) - self.assertListEqual(dense_layer.get_updates_for(None), [update_2]) - - def test_get_losses_for(self): - a = keras.layers.Input(shape=(1,)) - dense_layer = keras.layers.Dense(1) - dense_layer.build((None, 1)) - loss_1 = math_ops.reduce_sum(a) - loss_2 = math_ops.reduce_sum(dense_layer.kernel) - dense_layer.add_loss(loss_1, inputs=a) - dense_layer.add_loss(loss_2, inputs=None) - - self.assertListEqual(dense_layer.get_losses_for(a), [loss_1]) - self.assertListEqual(dense_layer.get_losses_for(None), [loss_2]) + def test_get_updates(self): + + class MyLayer(keras.layers.Layer): + + def build(self, input_shape): + self.a = self.add_variable('a', + (1, 1), + 'float32', + trainable=False) + self.b = self.add_variable('b', + (1, 1), + 'float32', + trainable=False) + self.add_update(state_ops.assign_add(self.a, [[1.]])) + self.built = True + + def call(self, inputs): + self.add_update(state_ops.assign_add(self.a, inputs), + inputs=True) + return inputs + 1 + + x1 = keras.Input(shape=(1,)) + layer = MyLayer() + _ = layer.apply(x1) + + self.assertEqual(len(layer.updates), 2) + self.assertEqual(len(layer.get_updates_for(x1)), 1) + self.assertEqual(len(layer.get_updates_for(None)), 1) + + x2 = keras.Input(shape=(1,)) + y2 = layer.apply(x2) + + self.assertEqual(len(layer.updates), 3) + self.assertEqual(len(layer.get_updates_for(x1)), 1) + self.assertEqual(len(layer.get_updates_for(x2)), 1) + self.assertEqual(len(layer.get_updates_for(None)), 1) + + network = keras.engine.Network(x2, y2) + self.assertEqual(len(network.updates), 2) + self.assertEqual(len(network.get_updates_for(x1)), 0) + self.assertEqual(len(network.get_updates_for(x2)), 1) + self.assertEqual(len(network.get_updates_for(None)), 1) + + x3 = keras.Input(shape=(1,)) + _ = layer.apply(x3) + self.assertEqual(len(network.updates), 2) + + x4 = keras.Input(shape=(1,)) + _ = network(x4) + self.assertEqual(len(network.updates), 3) + self.assertEqual(len(network.get_updates_for(x2)), 1) + self.assertEqual(len(network.get_updates_for(x4)), 1) + self.assertEqual(len(network.get_updates_for(None)), 1) + + network.add_update(state_ops.assign_add(layer.a, [[1]])) + self.assertEqual(len(network.updates), 4) + self.assertEqual(len(network.get_updates_for(None)), 2) + + network.add_update(state_ops.assign_add(layer.a, x4), inputs=True) + self.assertEqual(len(network.updates), 5) + self.assertEqual(len(network.get_updates_for(x4)), 2) + + def test_get_losses(self): + + class MyLayer(keras.layers.Layer): + + def build(self, input_shape): + self.a = self.add_variable('a', + (1, 1), + 'float32', + trainable=False) + self.b = self.add_variable('b', + (1, 1), + 'float32', + trainable=False) + self.add_loss(math_ops.reduce_sum(self.a)) + self.built = True + + def call(self, inputs): + self.add_loss(math_ops.reduce_sum(inputs), + inputs=True) + return inputs + 1 + + x1 = keras.Input(shape=(1,)) + layer = MyLayer() + _ = layer.apply(x1) + + self.assertEqual(len(layer.losses), 2) + self.assertEqual(len(layer.get_losses_for(x1)), 1) + self.assertEqual(len(layer.get_losses_for(None)), 1) + + x2 = keras.Input(shape=(1,)) + y2 = layer.apply(x2) + + self.assertEqual(len(layer.losses), 3) + self.assertEqual(len(layer.get_losses_for(x1)), 1) + self.assertEqual(len(layer.get_losses_for(x2)), 1) + self.assertEqual(len(layer.get_losses_for(None)), 1) + + network = keras.engine.Network(x2, y2) + self.assertEqual(len(network.losses), 2) + self.assertEqual(len(network.get_losses_for(x1)), 0) + self.assertEqual(len(network.get_losses_for(x2)), 1) + self.assertEqual(len(network.get_losses_for(None)), 1) + + x3 = keras.Input(shape=(1,)) + _ = layer.apply(x3) + self.assertEqual(len(network.losses), 2) + + x4 = keras.Input(shape=(1,)) + _ = network(x4) + self.assertEqual(len(network.losses), 3) + self.assertEqual(len(network.get_losses_for(x2)), 1) + self.assertEqual(len(network.get_losses_for(x4)), 1) + self.assertEqual(len(network.get_losses_for(None)), 1) + + network.add_loss(math_ops.reduce_sum(layer.a)) + self.assertEqual(len(network.losses), 4) + self.assertEqual(len(network.get_losses_for(None)), 2) + + network.add_loss(math_ops.reduce_sum(x4), inputs=True) + self.assertEqual(len(network.losses), 5) + self.assertEqual(len(network.get_losses_for(x4)), 2) + + def testTopologicalAttributes(self): + # test layer attributes / methods related to cross-layer connectivity. + a = keras.Input(shape=(32,), name='input_a') + b = keras.Input(shape=(32,), name='input_b') + + # test input, output, input_shape, output_shape + test_layer = keras.layers.Dense(16, name='test_layer') + a_test = test_layer(a) + self.assertEqual(test_layer.input, a) + self.assertEqual(test_layer.output, a_test) + self.assertEqual(test_layer.input_shape, (None, 32)) + self.assertEqual(test_layer.output_shape, (None, 16)) + + # test `get_*_at` methods + dense = keras.layers.Dense(16, name='dense_1') + a_2 = dense(a) + b_2 = dense(b) + + self.assertEqual(dense.get_input_at(0), a) + self.assertEqual(dense.get_input_at(1), b) + self.assertEqual(dense.get_output_at(0), a_2) + self.assertEqual(dense.get_output_at(1), b_2) + self.assertEqual(dense.get_input_shape_at(0), (None, 32)) + self.assertEqual(dense.get_input_shape_at(1), (None, 32)) + self.assertEqual(dense.get_output_shape_at(0), (None, 16)) + self.assertEqual(dense.get_output_shape_at(1), (None, 16)) + + # Test invalid value for attribute retrieval. + with self.assertRaises(ValueError): + dense.get_input_at(2) + with self.assertRaises(AttributeError): + new_dense = keras.layers.Dense(16) + _ = new_dense.input + with self.assertRaises(AttributeError): + new_dense = keras.layers.Dense(16) + _ = new_dense.output + with self.assertRaises(AttributeError): + new_dense = keras.layers.Dense(16) + _ = new_dense.output_shape + with self.assertRaises(AttributeError): + new_dense = keras.layers.Dense(16) + _ = new_dense.input_shape + with self.assertRaises(AttributeError): + new_dense = keras.layers.Dense(16) + a = keras.Input(shape=(3, 32)) + a = keras.Input(shape=(5, 32)) + a_2 = dense(a) + b_2 = dense(b) + _ = new_dense.input_shape + with self.assertRaises(AttributeError): + new_dense = keras.layers.Dense(16) + a = keras.Input(shape=(3, 32)) + a = keras.Input(shape=(5, 32)) + a_2 = dense(a) + b_2 = dense(b) + _ = new_dense.output_shape + + def testTopologicalAttributesMultiOutputLayer(self): + + class PowersLayer(keras.layers.Layer): + + def call(self, inputs): + return [inputs**2, inputs**3] + + x = keras.Input(shape=(32,)) + test_layer = PowersLayer() + p1, p2 = test_layer(x) # pylint: disable=not-callable + + self.assertEqual(test_layer.input, x) + self.assertEqual(test_layer.output, [p1, p2]) + self.assertEqual(test_layer.input_shape, (None, 32)) + self.assertEqual(test_layer.output_shape, [(None, 32), (None, 32)]) + + def testTopologicalAttributesMultiInputLayer(self): + + class AddLayer(keras.layers.Layer): + + def call(self, inputs): + assert len(inputs) == 2 + return inputs[0] + inputs[1] + + a = keras.Input(shape=(32,)) + b = keras.Input(shape=(32,)) + test_layer = AddLayer() + y = test_layer([a, b]) # pylint: disable=not-callable + + self.assertEqual(test_layer.input, [a, b]) + self.assertEqual(test_layer.output, y) + self.assertEqual(test_layer.input_shape, [(None, 32), (None, 32)]) + self.assertEqual(test_layer.output_shape, (None, 32)) + + def testBasicNetwork(self): + # minimum viable network + x = keras.Input(shape=(32,)) + dense = keras.layers.Dense(2) + y = dense(x) + network = keras.engine.Network(x, y, name='dense_network') + + # test basic attributes + self.assertEqual(network.name, 'dense_network') + self.assertEqual(len(network.layers), 2) # InputLayer + Dense + self.assertEqual(network.layers[1], dense) + self.assertEqual(network.weights, dense.weights) + self.assertEqual(network.trainable_weights, dense.trainable_weights) + self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights) + + # test callability on Input + x_2 = keras.Input(shape=(32,)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 2]) + + # test callability on regular tensor + x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 2]) + + # test network `trainable` attribute + network.trainable = False + self.assertEqual(network.weights, dense.weights) + self.assertEqual(network.trainable_weights, []) + self.assertEqual(network.non_trainable_weights, + dense.trainable_weights + dense.non_trainable_weights) def test_trainable_weights(self): a = keras.layers.Input(shape=(2,)) @@ -108,41 +327,6 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual(model.trainable_weights, []) self.assertListEqual(model.non_trainable_weights, weights) - def test_weight_loading(self): - with self.test_session(): - a = keras.layers.Input(shape=(2,)) - x = keras.layers.Dense(3)(a) - b = keras.layers.Dense(1)(x) - model = keras.models.Model(a, b) - - x = np.random.random((3, 2)) - ref_y = model.predict(x) - weights = model.get_weights() - model.set_weights(weights) - y = model.predict(x) - self.assertAllClose(ref_y, y) - - with self.assertRaises(ValueError): - model.set_weights(weights[1:]) - with self.assertRaises(ValueError): - model.set_weights(weights[::-1]) - - if h5py is None: - return # Skip rest of test if H5py isn't available. - - temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) - - h5_path = os.path.join(temp_dir, 'test.h5') - model.save_weights(h5_path) - model.load_weights(h5_path) - y = model.predict(x) - self.assertAllClose(ref_y, y) - - model.load_weights(h5_path, by_name=True) - y = model.predict(x) - self.assertAllClose(ref_y, y) - def test_learning_phase(self): with self.test_session(): a = keras.layers.Input(shape=(32,), name='input_a') @@ -310,7 +494,7 @@ class TopologyConstructionTest(test.TestCase): self.assertListEqual([x.shape for x in fn_outputs], [(10, 64), (10, 5)]) # test get_source_inputs - self.assertListEqual(keras.engine.topology.get_source_inputs(c), [a, b]) + self.assertListEqual(keras.engine.network.get_source_inputs(c), [a, b]) # serialization / deserialization json_config = model.to_json() @@ -348,7 +532,7 @@ class TopologyConstructionTest(test.TestCase): e = keras.layers.Input(shape=(32,), name='input_e') f = keras.layers.Input(shape=(32,), name='input_f') g, h = model([e, f]) - self.assertEqual(g.name, 'model_1/dense_2/BiasAdd:0') + self.assertEqual(g.name, 'model/dense_2/BiasAdd:0') self.assertListEqual(g.get_shape().as_list(), c.get_shape().as_list()) self.assertListEqual(h.get_shape().as_list(), d.get_shape().as_list()) @@ -555,96 +739,62 @@ class TopologyConstructionTest(test.TestCase): model = keras.models.Model(a, b) self.assertEqual(model.output_mask.get_shape().as_list(), [None, 10]) - def test_weight_preprocessing(self): - input_dim = 3 - output_dim = 3 - size = 2 - cases = [ - [ - (keras.layers.Bidirectional(keras.layers.SimpleRNN(2))), - [np.random.random((2, 1)), np.random.random((2, 1))], - (None, 3, 2), - ], - [ - (keras.layers.TimeDistributed(keras.layers.Dense(1))), - [np.random.random((2, 1)), np.random.random((1,))], - (None, 3, 2), - ], - [ - (keras.layers.Conv1D(output_dim, size, use_bias=False)), - [np.random.random((output_dim, input_dim, size, 1))], - (None, 4, input_dim), - ], - [ - (keras.layers.Conv2D(output_dim, size, - use_bias=False, data_format='channels_first')), - [np.random.random((output_dim, input_dim, size, size))], - (None, input_dim, 4, 4), - ], - [ - (keras.layers.Conv2DTranspose(output_dim, size, - use_bias=False, - data_format='channels_first')), - [np.random.random((output_dim, input_dim, size, size))], - (None, input_dim, 4, 4), - ], - [ - (keras.layers.Conv2DTranspose(output_dim, size, - use_bias=False, - data_format='channels_last')), - [np.random.random((size, size, input_dim, output_dim))], - (None, 4, 4, input_dim), - ], - [ - (keras.layers.Conv3D(output_dim, size, - use_bias=False, data_format='channels_first')), - [np.random.random((output_dim, input_dim, size, size, size))], - (None, input_dim, 4, 4, 4), - ], - [ - (keras.layers.GRU(output_dim)), - [np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,)), - np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,)), - np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,))], - (None, 4, input_dim), - ], - [ - (keras.layers.LSTM(output_dim)), - [np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,)), - np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,)), - np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,)), - np.random.random((input_dim, output_dim)), - np.random.random((output_dim, output_dim)), - np.random.random((output_dim,))], - (None, 4, input_dim), - ], - ] - for layer, weights, input_shape in cases: - layer.build(input_shape) - _ = keras.engine.topology.preprocess_weights_for_loading( - layer, weights, original_keras_version='1') - - model = keras.models.Sequential([keras.layers.Dense(2, input_dim=2)]) - _ = keras.engine.topology.preprocess_weights_for_loading( - model, model.weights, original_keras_version='1') - - x = keras.Input((2,)) - y = keras.layers.Dense(2)(x) - model = keras.models.Model(x, y) - _ = keras.engine.topology.preprocess_weights_for_loading( - model, model.weights, original_keras_version='1') + def testMaskingSingleInput(self): + + class MaskedLayer(keras.layers.Layer): + + def call(self, inputs, mask=None): + if mask is not None: + return inputs * mask + return inputs + + def compute_mask(self, inputs, mask=None): + return array_ops.ones_like(inputs) + + if context.in_graph_mode(): + x = keras.Input(shape=(32,)) + y = MaskedLayer()(x) # pylint: disable=not-callable + network = keras.engine.Network(x, y) + + # test callability on Input + x_2 = keras.Input(shape=(32,)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 32]) + + # test callability on regular tensor + x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) + y_2 = network(x_2) + self.assertEqual(y_2.get_shape().as_list(), [None, 32]) + else: + a = constant_op.constant([2] * 32) + mask = constant_op.constant([0, 1] * 16) + a._keras_mask = mask + b = MaskedLayer().apply(a) + self.assertTrue(hasattr(b, '_keras_mask')) + self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)), + self.evaluate(getattr(b, '_keras_mask'))) + self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) + + def test_activity_regularization_with_model_composition(self): + + def reg(x): + return keras.backend.sum(x) + + net_a_input = keras.Input((2,)) + net_a = net_a_input + net_a = keras.layers.Dense(2, kernel_initializer='ones', + use_bias=False, + activity_regularizer=reg)(net_a) + model_a = keras.Model([net_a_input], [net_a]) + + net_b_input = keras.Input((2,)) + net_b = model_a(net_b_input) + model_b = keras.Model([net_b_input], [net_b]) + + model_b.compile(optimizer='sgd', loss=None) + x = np.ones((1, 2)) + loss = model_b.evaluate(x) + self.assertEqual(loss, 4.) def test_layer_sharing_at_heterogenous_depth(self): with self.test_session(): @@ -694,5 +844,92 @@ class TopologyConstructionTest(test.TestCase): output_val_2 = m2.predict(x_val) self.assertAllClose(output_val, output_val_2, atol=1e-6) + def test_explicit_training_argument(self): + with self.test_session(): + a = keras.layers.Input(shape=(2,)) + b = keras.layers.Dropout(0.5)(a) + base_model = keras.models.Model(a, b) + + a = keras.layers.Input(shape=(2,)) + b = base_model(a, training=False) + model = keras.models.Model(a, b) + + x = np.ones((100, 2)) + y = np.ones((100, 2)) + model.compile(optimizer='sgd', loss='mse') + loss = model.train_on_batch(x, y) + self.assertEqual(loss, 0) # In inference mode, output is equal to input. + + a = keras.layers.Input(shape=(2,)) + b = base_model(a, training=True) + model = keras.models.Model(a, b) + preds = model.predict(x) + self.assertEqual(np.min(preds), 0.) # At least one unit was dropped. + + +class DeferredModeTest(test.TestCase): + + def testDeferredTensorAttributes(self): + x = tf_base_layers._DeferredTensor(shape=(None, 2), + dtype='float32', + name='x') + self.assertEqual(str(x), + 'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)') + self.assertEqual(repr(x), + '<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>') + + @test_util.run_in_graph_and_eager_modes() + def testSimpleNetworkBuilding(self): + inputs = keras.engine.Input(shape=(32,)) + if context.in_eager_mode(): + self.assertIsInstance(inputs, tf_base_layers._DeferredTensor) + self.assertEqual(inputs.dtype.name, 'float32') + self.assertEqual(inputs.shape.as_list(), [None, 32]) + + x = keras.layers.Dense(2)(inputs) + if context.in_eager_mode(): + self.assertIsInstance(x, tf_base_layers._DeferredTensor) + self.assertEqual(x.dtype.name, 'float32') + self.assertEqual(x.shape.as_list(), [None, 2]) + + outputs = keras.layers.Dense(4)(x) + network = keras.engine.Network(inputs, outputs) + self.assertIsInstance(network, keras.engine.Network) + + if context.in_eager_mode(): + # It should be possible to call such a network on EagerTensors. + inputs = constant_op.constant( + np.random.random((10, 32)).astype('float32')) + outputs = network(inputs) + self.assertEqual(outputs.shape.as_list(), [10, 4]) + + @test_util.run_in_graph_and_eager_modes() + def testMultiIONetworkbuilding(self): + input_a = keras.engine.Input(shape=(32,)) + input_b = keras.engine.Input(shape=(16,)) + a = keras.layers.Dense(16)(input_a) + + class AddLayer(keras.layers.Layer): + + def call(self, inputs): + return inputs[0] + inputs[1] + + def compute_output_shape(self, input_shape): + return input_shape[0] + + c = AddLayer()([a, input_b]) # pylint: disable=not-callable + c = keras.layers.Dense(2)(c) + + network = keras.engine.Network([input_a, input_b], [a, c]) + if context.in_eager_mode(): + a_val = constant_op.constant( + np.random.random((10, 32)).astype('float32')) + b_val = constant_op.constant( + np.random.random((10, 16)).astype('float32')) + outputs = network([a_val, b_val]) + self.assertEqual(len(outputs), 2) + self.assertEqual(outputs[0].shape.as_list(), [10, 16]) + self.assertEqual(outputs[1].shape.as_list(), [10, 2]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py index 118598831d0c906a2e5229f4f26441180d291ccf..63bea08ac55ca185f952a805282abb5872d08fbd 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training.py +++ b/tensorflow/python/keras/_impl/keras/engine/training.py @@ -24,17 +24,23 @@ import numpy as np from tensorflow.python.eager import context from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import callbacks as cbks from tensorflow.python.keras._impl.keras import losses from tensorflow.python.keras._impl.keras import metrics as metrics_module from tensorflow.python.keras._impl.keras import optimizers from tensorflow.python.keras._impl.keras.engine import training_eager -from tensorflow.python.keras._impl.keras.engine.topology import Network +from tensorflow.python.keras._impl.keras.engine.base_layer import Layer +from tensorflow.python.keras._impl.keras.engine.network import Network from tensorflow.python.keras._impl.keras.utils.data_utils import GeneratorEnqueuer from tensorflow.python.keras._impl.keras.utils.data_utils import OrderedEnqueuer from tensorflow.python.keras._impl.keras.utils.data_utils import Sequence +from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar +from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays +from tensorflow.python.layers.base import _DeferredTensor +from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import optimizer as tf_optimizer_module from tensorflow.python.util.tf_export import tf_export @@ -220,9 +226,9 @@ def _check_array_lengths(inputs, targets, weights=None): # return a set with the variation between # different shapes, with None => 0 if x is None: - return {0} + return {} else: - return set([0 if y is None else y.shape[0] for y in x]) + return set([y.shape[0] for y in x if y is not None]) set_x = set_of_lengths(inputs) set_y = set_of_lengths(targets) @@ -254,7 +260,8 @@ def _check_array_lengths(inputs, targets, weights=None): def _check_loss_and_target_compatibility(targets, loss_fns, output_shapes): """Does validation on the compatibility of targets and loss functions. - This helps prevent users from using loss functions incorrectly. + This helps prevent users from using loss functions incorrectly. This check + is purely for UX purposes. Arguments: targets: list of Numpy arrays of targets. @@ -270,7 +277,7 @@ def _check_loss_and_target_compatibility(targets, loss_fns, output_shapes): losses.categorical_crossentropy } for y, loss, shape in zip(targets, loss_fns, output_shapes): - if loss is None: + if y is None or loss is None or tensor_util.is_tensor(y): continue if loss is losses.categorical_crossentropy: if y.shape[-1] == 1: @@ -360,62 +367,6 @@ def _batch_shuffle(index_array, batch_size): return np.append(index_array, last_batch) -def _make_batches(size, batch_size): - """Returns a list of batch indices (tuples of indices). - - Arguments: - size: Integer, total size of the data to slice into batches. - batch_size: Integer, batch size. - - Returns: - A list of tuples of array indices. - """ - num_batches = (size + batch_size - 1) // batch_size # round up - return [(i * batch_size, min(size, (i + 1) * batch_size)) - for i in range(num_batches)] - - -def _slice_arrays(arrays, start=None, stop=None): - """Slice an array or list of arrays. - - This takes an array-like, or a list of - array-likes, and outputs: - - arrays[start:stop] if `arrays` is an array-like - - [x[start:stop] for x in arrays] if `arrays` is a list - - Can also work on list/array of indices: `_slice_arrays(x, indices)` - - Arguments: - arrays: Single array or list of arrays. - start: can be an integer index (start index) - or a list/array of indices - stop: integer (stop index); should be None if - `start` was a list. - - Returns: - A slice of the array(s). - """ - if arrays is None: - return [None] - elif isinstance(arrays, list): - if hasattr(start, '__len__'): - # hdf5 datasets only support list objects as indices - if hasattr(start, 'shape'): - start = start.tolist() - return [None if x is None else x[start] for x in arrays] - else: - return [None if x is None else x[start:stop] for x in arrays] - else: - if hasattr(start, '__len__'): - if hasattr(start, 'shape'): - start = start.tolist() - return arrays[start] - elif hasattr(start, '__getitem__'): - return arrays[start:stop] - else: - return [None] - - def _weighted_masked_objective(fn): """Adds support for masking and sample-weighting to an objective function. @@ -539,7 +490,7 @@ def _standardize_weights(y, raise ValueError('`class_weight` not supported for ' '3+ dimensional targets.') if y.shape[1] > 1: - y_classes = y.argmax(axis=1) + y_classes = np.argmax(y, axis=1) elif y.shape[1] == 1: y_classes = np.reshape(y, y.shape[0]) else: @@ -558,20 +509,75 @@ def _standardize_weights(y, (existing_classes - existing_class_weight)) return weights else: - if sample_weight_mode is None: - return np.ones((y.shape[0],), dtype=K.floatx()) - else: - return np.ones((y.shape[0], y.shape[1]), dtype=K.floatx()) + return None @tf_export('keras.models.Model', 'keras.Model') class Model(Network): - """The `Model` class adds training & evaluation routines to a `Network`. + """`Model` groups layers into an object with training and inference features. + + There are two ways to instantiate a `Model`: + + 1 - With the "functional API", where you start from `Input`, + you chain layer calls to specify the model's forward pass, + and finally you create your model from inputs and outputs: + + ```python + import tensorflow as tf + + inputs = tf.keras.Input(shape=(3,)) + x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) + outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) + model = tf.keras.Model(inputs=inputs, outputs=outputs) + ``` + + 2 - By subclassing the `Model` class: in that case, you should define your + layers in `__init__` and you should implement the model's forward pass + in `call`. + + ```python + import tensorflow as tf + + class MyModel(tf.keras.Model): + + def __init__(self): + self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) + self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) + + def call(self, inputs): + x = self.dense1(inputs) + return self.dense2(x) + + model = MyModel() + ``` + + If you subclass `Model`, you can optionally have + a `training` argument (boolean) in `call`, which you can use to specify + a different behavior in training and inference: + + ```python + import tensorflow as tf + + class MyModel(tf.keras.Model): + + def __init__(self): + self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu) + self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax) + self.dropout = tf.keras.layers.Dropout(0.5) + + def call(self, inputs, training=False): + x = self.dense1(inputs) + if training: + x = self.dropout(x, training=training) + return self.dense2(x) + + model = MyModel() + ``` """ def compile(self, optimizer, - loss, + loss=None, metrics=None, loss_weights=None, sample_weight_mode=None, @@ -628,15 +634,29 @@ class Model(Network): """ loss = loss or {} if context.in_eager_mode() and not isinstance( - optimizer, tf_optimizer_module.Optimizer): + optimizer, (tf_optimizer_module.Optimizer, optimizers.TFOptimizer)): raise ValueError('Only TF native optimizers are supported in Eager mode.') self.optimizer = optimizers.get(optimizer) self.loss = loss + self.metrics = metrics or [] self.loss_weights = loss_weights if context.in_eager_mode() and sample_weight_mode is not None: raise ValueError('sample_weight_mode is not supported in Eager mode.') self.sample_weight_mode = sample_weight_mode + if context.in_eager_mode() and weighted_metrics is not None: + raise ValueError('weighted_metrics is not supported in Eager mode.') + self.weighted_metrics = weighted_metrics + if context.in_eager_mode() and target_tensors is not None: + raise ValueError('target_tensors is not supported in Eager mode.') + self.target_tensors = target_tensors + + if not self.built: + # Model is not compilable because it does not know its number of inputs + # and outputs, nor their shapes and names. We will compile after the first + # time the model gets called on training data. + return + self._is_compiled = True # Prepare loss functions. if isinstance(loss, dict): @@ -719,8 +739,6 @@ class Model(Network): raise ValueError('target_tensors are not currently supported in Eager' 'mode.') self.total_loss = None - self.metrics = metrics - self.weighted_metrics = weighted_metrics self.metrics_tensors = [] self.metrics_names = ['loss'] for i in range(len(self.outputs)): @@ -732,16 +750,15 @@ class Model(Network): self._feed_sample_weight_modes.append(None) self.sample_weights = [] self.targets = [] - self._collected_trainable_weights = self.trainable_weights for i in range(len(self.outputs)): self._feed_output_names.append(self.output_names[i]) - + self._collected_trainable_weights = self.trainable_weights return # Prepare targets of model. self.targets = [] self._feed_targets = [] - if target_tensors is not None: + if target_tensors not in (None, []): if isinstance(target_tensors, list): if len(target_tensors) != len(self.outputs): raise ValueError( @@ -768,9 +785,9 @@ class Model(Network): if i in skip_target_indices: self.targets.append(None) else: - shape = self._internal_output_shapes[i] + shape = K.int_shape(self.outputs[i]) name = self.output_names[i] - if target_tensors is not None: + if target_tensors not in (None, []): target = target_tensors[i] else: target = None @@ -844,12 +861,12 @@ class Model(Network): sample_weights.append(None) else: if sample_weight_mode == 'temporal': - sample_weights.append( - K.placeholder(ndim=2, name=name + '_sample_weights')) + sample_weights.append(array_ops.placeholder_with_default( + [[1.]], shape=[None, None], name=name + '_sample_weights')) sample_weight_modes.append('temporal') else: - sample_weights.append( - K.placeholder(ndim=1, name=name + '_sample_weights')) + sample_weights.append(array_ops.placeholder_with_default( + [1.], shape=[None], name=name + '_sample_weights')) sample_weight_modes.append(None) self.sample_weight_modes = sample_weight_modes self._feed_sample_weight_modes = [] @@ -858,7 +875,6 @@ class Model(Network): self._feed_sample_weight_modes.append(self.sample_weight_modes[i]) # Prepare metrics. - self.metrics = metrics self.weighted_metrics = weighted_metrics self.metrics_names = ['loss'] self.metrics_tensors = [] @@ -901,14 +917,8 @@ class Model(Network): nested_metrics = _collect_metrics(metrics, self.output_names) nested_weighted_metrics = _collect_metrics(weighted_metrics, self.output_names) - - def append_metric(layer_index, metric_name, metric_tensor): - """Helper function used in loop below.""" - if len(self.output_names) > 1: - metric_name = self.output_names[layer_index] + '_' + metric_name - self.metrics_names.append(metric_name) - self.metrics_tensors.append(metric_tensor) - + self.metrics_updates = [] + self.stateful_metric_names = [] with K.name_scope('metrics'): for i in range(len(self.outputs)): if i in skip_target_indices: @@ -927,42 +937,65 @@ class Model(Network): if metric in ('accuracy', 'acc', 'crossentropy', 'ce'): # custom handling of accuracy/crossentropy # (because of class mode duality) - output_shape = self._internal_output_shapes[i] + output_shape = self.outputs[i].get_shape().as_list() if (output_shape[-1] == 1 or self.loss_functions[i] == losses.binary_crossentropy): # case: binary accuracy/crossentropy if metric in ('accuracy', 'acc'): - acc_fn = metrics_module.binary_accuracy + metric_fn = metrics_module.binary_accuracy elif metric in ('crossentropy', 'ce'): - acc_fn = metrics_module.binary_crossentropy + metric_fn = metrics_module.binary_crossentropy elif self.loss_functions[ i] == losses.sparse_categorical_crossentropy: # case: categorical accuracy/crossentropy with sparse targets if metric in ('accuracy', 'acc'): - acc_fn = metrics_module.sparse_categorical_accuracy + metric_fn = metrics_module.sparse_categorical_accuracy elif metric in ('crossentropy', 'ce'): - acc_fn = metrics_module.sparse_categorical_crossentropy + metric_fn = metrics_module.sparse_categorical_crossentropy else: # case: categorical accuracy/crossentropy if metric in ('accuracy', 'acc'): - acc_fn = metrics_module.categorical_accuracy + metric_fn = metrics_module.categorical_accuracy elif metric in ('crossentropy', 'ce'): - acc_fn = metrics_module.categorical_crossentropy + metric_fn = metrics_module.categorical_crossentropy if metric in ('accuracy', 'acc'): suffix = 'acc' elif metric in ('crossentropy', 'ce'): suffix = 'ce' - weighted_metric_fn = _weighted_masked_objective(acc_fn) + weighted_metric_fn = _weighted_masked_objective(metric_fn) metric_name = metric_name_prefix + suffix else: metric_fn = metrics_module.get(metric) weighted_metric_fn = _weighted_masked_objective(metric_fn) - metric_name = metric_name_prefix + metric_fn.__name__ + # Get metric name as string + if hasattr(metric_fn, 'name'): + metric_name = metric_fn.name + else: + metric_name = metric_fn.__name__ + metric_name = metric_name_prefix + metric_name with K.name_scope(metric_name): metric_result = weighted_metric_fn( y_true, y_pred, weights=weights, mask=masks[i]) - append_metric(i, metric_name, metric_result) + + # Append to self.metrics_names, self.metric_tensors, + # self.stateful_metric_names + if len(self.output_names) > 1: + metric_name = '%s_%s' % (self.output_names[i], metric_name) + # Dedupe name + j = 1 + base_metric_name = metric_name + while metric_name in self.metrics_names: + metric_name = '%s_%d' % (base_metric_name, j) + j += 1 + self.metrics_names.append(metric_name) + self.metrics_tensors.append(metric_result) + + # Keep track of state updates created by + # stateful metrics (i.e. metrics layers). + if isinstance(metric_fn, Layer): + self.stateful_metric_names.append(metric_name) + self.metrics_updates += metric_fn.updates handle_metrics(output_metrics) handle_metrics(output_weighted_metrics, weights=weights) @@ -1027,6 +1060,8 @@ class Model(Network): updates += self.get_updates_for(None) # Conditional updates relevant to this model updates += self.get_updates_for(self._feed_inputs) + # Stateful metrics updates + updates += self.metrics_updates # Gets loss and metrics. Updates weights at each call. self.train_function = K.function( inputs, [self.total_loss] + self.metrics_tensors, @@ -1047,7 +1082,7 @@ class Model(Network): # Does update the network states. self.test_function = K.function( inputs, [self.total_loss] + self.metrics_tensors, - updates=self.state_updates, + updates=self.state_updates + self.metrics_updates, name='test_function', **self._function_kwargs) @@ -1186,14 +1221,18 @@ class Model(Network): index_array = np.arange(num_train_samples) self.history = cbks.History() - callbacks = [cbks.BaseLogger()] + (callbacks or []) + [self.history] + all_callbacks = [cbks.BaseLogger( + stateful_metrics=self.stateful_metric_names)] if verbose: if steps_per_epoch is not None: count_mode = 'steps' else: count_mode = 'samples' - callbacks += [cbks.ProgbarLogger(count_mode)] - callbacks = cbks.CallbackList(callbacks) + all_callbacks.append( + cbks.ProgbarLogger( + count_mode, stateful_metrics=self.stateful_metric_names)) + all_callbacks += (callbacks or []) + [self.history] + callbacks = cbks.CallbackList(all_callbacks) out_labels = out_labels or [] # it's possible to callback a different model than self @@ -1227,6 +1266,11 @@ class Model(Network): indices_for_conversion_to_dense.append(i) for epoch in range(initial_epoch, epochs): + # Reset stateful metrics + for m in self.metrics: + if isinstance(m, Layer): + m.reset_states() + # Update callbacks callbacks.on_epoch_begin(epoch) epoch_logs = {} if steps_per_epoch is not None: @@ -1264,16 +1308,16 @@ class Model(Network): elif shuffle: np.random.shuffle(index_array) - batches = _make_batches(num_train_samples, batch_size) + batches = make_batches(num_train_samples, batch_size) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] try: - if isinstance(ins[-1], float): + if isinstance(ins[-1], int): # Do not slice the training phase flag. - ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] else: - ins_batch = _slice_arrays(ins, batch_ids) + ins_batch = slice_arrays(ins, batch_ids) except TypeError: raise TypeError('TypeError while preparing batch. ' 'If using HDF5 input data, ' @@ -1327,12 +1371,19 @@ class Model(Network): or list of arrays of predictions (if the model has multiple outputs). """ + if hasattr(self, 'metrics'): + for m in self.metrics: + if isinstance(m, Layer): + m.reset_states() + num_samples = self._check_num_samples(ins, batch_size, steps, 'steps') if verbose == 1: if steps is not None: - progbar = Progbar(target=steps) + progbar = Progbar(target=steps, + stateful_metrics=self.stateful_metric_names) else: - progbar = Progbar(target=num_samples) + progbar = Progbar(target=num_samples, + stateful_metrics=self.stateful_metric_names) indices_for_conversion_to_dense = [] for i in range(len(self._feed_inputs)): @@ -1368,15 +1419,15 @@ class Model(Network): else: # Sample-based predictions. outs = [] - batches = _make_batches(num_samples, batch_size) + batches = make_batches(num_samples, batch_size) index_array = np.arange(num_samples) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] - if ins and isinstance(ins[-1], float): + if ins and isinstance(ins[-1], int): # Do not slice the training phase flag. - ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] else: - ins_batch = _slice_arrays(ins, batch_ids) + ins_batch = slice_arrays(ins, batch_ids) for i in indices_for_conversion_to_dense: ins_batch[i] = ins_batch[i].toarray() @@ -1414,6 +1465,17 @@ class Model(Network): and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. """ + if hasattr(self, 'metrics'): + for m in self.metrics: + if isinstance(m, Layer): + m.reset_states() + stateful_metric_indices = [ + i for i, name in enumerate(self.metrics_names) + if str(name) in self.stateful_metric_names + ] + else: + stateful_metric_indices = [] + num_samples = self._check_num_samples(ins, batch_size, steps, 'steps') outs = [] if verbose == 1: @@ -1437,7 +1499,10 @@ class Model(Network): for _ in enumerate(batch_outs): outs.append(0.) for i, batch_out in enumerate(batch_outs): - outs[i] += batch_out + if i in stateful_metric_indices: + outs[i] = batch_out + else: + outs[i] += batch_out else: if step == 0: outs.append(0.) @@ -1445,17 +1510,18 @@ class Model(Network): if verbose == 1: progbar.update(step + 1) for i in range(len(outs)): - outs[i] /= steps + if i not in stateful_metric_indices: + outs[i] /= steps else: - batches = _make_batches(num_samples, batch_size) + batches = make_batches(num_samples, batch_size) index_array = np.arange(num_samples) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] - if isinstance(ins[-1], float): + if isinstance(ins[-1], int): # Do not slice the training phase flag. - ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] else: - ins_batch = _slice_arrays(ins, batch_ids) + ins_batch = slice_arrays(ins, batch_ids) for i in indices_for_conversion_to_dense: ins_batch[i] = ins_batch[i].toarray() @@ -1466,7 +1532,10 @@ class Model(Network): for batch_out in enumerate(batch_outs): outs.append(0.) for i, batch_out in enumerate(batch_outs): - outs[i] += batch_out * len(batch_ids) + if i in stateful_metric_indices: + outs[i] = batch_out + else: + outs[i] += batch_out * len(batch_ids) else: if batch_index == 0: outs.append(0.) @@ -1474,62 +1543,221 @@ class Model(Network): if verbose == 1: progbar.update(batch_end) for i in range(len(outs)): - outs[i] /= num_samples + if i not in stateful_metric_indices: + outs[i] /= num_samples if len(outs) == 1: return outs[0] return outs def _standardize_user_data(self, x, - y, + y=None, sample_weight=None, class_weight=None, - check_batch_axis=True, batch_size=None): - if not hasattr(self, 'optimizer'): - raise RuntimeError('You must compile a model before ' - 'training/testing. ' - 'Use `model.compile(optimizer, loss)`.') - - output_shapes = [] - for output_shape, loss_fn in zip(self._feed_output_shapes, - self._feed_loss_fns): - if loss_fn is losses.sparse_categorical_crossentropy: - output_shapes.append(output_shape[:-1] + (1,)) - elif (not hasattr(loss_fn, '__name__') or - getattr(losses, loss_fn.__name__, None) is None): - # If `loss_fn` is not a function (e.g. callable class) - # or if it not in the `losses` module, then - # it is a user-defined loss and we make no assumptions - # about it. - output_shapes.append(None) + """Runs validation checks on input and target data passed by the user. + + Also standardizes the data to lists of arrays, in order. + + Also builds and compiles the model on the fly if it is a subclassed model + that has never been called before (and thus has no inputs/outputs). + + This is a purely internal method, subject to refactoring at any time. + + Args: + x: An array or list of arrays, to be used as input data. If the model + has known, named inputs, this could also be a dict mapping input names + to the corresponding array. + y: An array or list of arrays, to be used as target data. If the model + has known, named outputs, this could also be a dict mapping output names + to the corresponding array. + sample_weight: An optional sample-weight array passed by the user to + weight the importance of each sample in `x`. + class_weight: An optional class-weight array by the user to + weight the importance of samples in `x` based on the class they belong + to, as conveyed by `y`. + batch_size: Integer batch size. If provided, it is used to run additional + validation checks on stateful models. + + Returns: + A tuple of 3 lists: input arrays, target arrays, sample-weight arrays. + If the model's input and targets are symbolic, these lists are empty + (since the model takes no user-provided data, instead the data comes + from the symbolic inputs/targets). + + Raises: + ValueError: In case of invalid user-provided data. + RuntimeError: If the model was never compiled. + """ + # First, we build/compile the model on the fly if necessary. + all_inputs = [] + if not self.built: + # We need to use `x` to set the model inputs. + # We type-check that `x` and `y` are either single arrays + # or lists of arrays. + if isinstance(x, (list, tuple)): + if not all(isinstance(v, np.ndarray) or + tensor_util.is_tensor(v) for v in x): + raise ValueError('Please provide as model inputs either a single ' + 'array or a list of arrays. You passed: x=' + str(x)) + all_inputs += list(x) + elif isinstance(x, dict): + raise ValueError('Please do not pass a dictionary as model inputs.') else: - output_shapes.append(output_shape) + if not isinstance(x, np.ndarray) and not tensor_util.is_tensor(x): + raise ValueError('Please provide as model inputs either a single ' + 'array or a list of arrays. You passed: x=' + str(x)) + all_inputs.append(x) + + # Build the model using the retrieved inputs (value or symbolic). + # If values, then in symbolic-mode placeholders will be created + # to match the value shapes. + if not self.inputs: + self._set_inputs(x) + + if y is not None: + if not self.optimizer: + raise RuntimeError('You must compile a model before ' + 'training/testing. ' + 'Use `model.compile(optimizer, loss)`.') + if not self._is_compiled: + # On-the-fly compilation of the model. + # We need to use `y` to set the model targets. + if isinstance(y, (list, tuple)): + if not all(isinstance(v, np.ndarray) or + tensor_util.is_tensor(v) for v in y): + raise ValueError('Please provide as model targets either a single ' + 'array or a list of arrays. ' + 'You passed: y=' + str(y)) + elif isinstance(y, dict): + raise ValueError('Please do not pass a dictionary as model targets.') + else: + if not isinstance(y, np.ndarray) and not tensor_util.is_tensor(y): + raise ValueError('Please provide as model targets either a single ' + 'array or a list of arrays. ' + 'You passed: y=' + str(y)) + + # Typecheck that all inputs are *either* value *or* symbolic. + # TODO(fchollet): this check could be removed in Eager mode? + if y is not None: + if isinstance(y, (list, tuple)): + all_inputs += list(y) + else: + all_inputs.append(y) + if any(tensor_util.is_tensor(v) for v in all_inputs): + if not all(tensor_util.is_tensor(v) for v in all_inputs): + raise ValueError('Do not pass inputs that mix Numpy arrays and ' + 'TensorFlow tensors. ' + 'You passed: x=' + str(x) + '; y=' + str(y)) + + if context.in_graph_mode(): + # Handle target tensors if any passed. + if not isinstance(y, (list, tuple)): + y = [y] + target_tensors = [v for v in y if tensor_util.is_tensor(v)] + else: + target_tensors = None + self.compile(optimizer=self.optimizer, + loss=self.loss, + metrics=self.metrics, + loss_weights=self.loss_weights, + target_tensors=target_tensors) + + # If `x` and `y` were all symbolic, then no model should not be fed any + # inputs and targets. + # Note: in this case, `any` and `all` are equivalent since we disallow + # mixed symbolic/value inputs. + if any(tensor_util.is_tensor(v) for v in all_inputs): + return [], [], [] + + # What follows is input validation and standardization to list format, + # in the case where all inputs are value arrays. + + if context.in_eager_mode(): + # In eager mode, do not do shape validation. + feed_input_names = self.input_names + feed_input_shapes = None + elif not self._is_graph_network: + # Case: symbolic-mode subclassed network. Do not do shape validation. + feed_input_names = self._feed_input_names + feed_input_shapes = None + else: + # Case: symbolic-mode graph network. + # In this case, we run extensive shape validation checks. + feed_input_names = self._feed_input_names + feed_input_shapes = self._feed_input_shapes + + # Standardize the inputs. x = _standardize_input_data( x, - self._feed_input_names, - self._feed_input_shapes, - check_batch_axis=False, + feed_input_names, + feed_input_shapes, + check_batch_axis=False, # Don't enforce the batch size. exception_prefix='input') - y = _standardize_input_data( - y, - self._feed_output_names, - output_shapes, - check_batch_axis=False, - exception_prefix='target') - sample_weights = _standardize_sample_weights(sample_weight, - self._feed_output_names) - class_weights = _standardize_class_weights(class_weight, - self._feed_output_names) - sample_weights = [ - _standardize_weights(ref, sw, cw, mode) - for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, - self._feed_sample_weight_modes) - ] - _check_array_lengths(x, y, sample_weights) - _check_loss_and_target_compatibility(y, self._feed_loss_fns, - self._feed_output_shapes) + + if y is not None: + if context.in_eager_mode(): + feed_output_names = self.output_names + feed_output_shapes = None + # Sample weighting not supported in this case. + # TODO(fchollet): consider supporting it. + feed_sample_weight_modes = [None for _ in self.outputs] + elif not self._is_graph_network: + feed_output_names = self._feed_output_names + feed_output_shapes = None + # Sample weighting not supported in this case. + # TODO(fchollet): consider supporting it. + feed_sample_weight_modes = [None for _ in self.outputs] + else: + feed_output_names = self._feed_output_names + feed_sample_weight_modes = self._feed_sample_weight_modes + feed_output_shapes = [] + for output_shape, loss_fn in zip(self._feed_output_shapes, + self._feed_loss_fns): + if loss_fn is losses.sparse_categorical_crossentropy: + feed_output_shapes.append(output_shape[:-1] + (1,)) + elif (not hasattr(loss_fn, '__name__') or + getattr(losses, loss_fn.__name__, None) is None): + # If `loss_fn` is not a function (e.g. callable class) + # or if it not in the `losses` module, then + # it is a user-defined loss and we make no assumptions + # about it. + feed_output_shapes.append(None) + else: + feed_output_shapes.append(output_shape) + + # Standardize the outputs. + y = _standardize_input_data( + y, + feed_output_names, + feed_output_shapes, + check_batch_axis=False, # Don't enforce the batch size. + exception_prefix='target') + + # Generate sample-wise weight values given the `sample_weight` and + # `class_weight` arguments. + sample_weights = _standardize_sample_weights(sample_weight, + feed_output_names) + class_weights = _standardize_class_weights(class_weight, + feed_output_names) + sample_weights = [ + _standardize_weights(ref, sw, cw, mode) + for (ref, sw, cw, mode) in zip(y, sample_weights, class_weights, + feed_sample_weight_modes) + ] + # Check that all arrays have the same length. + _check_array_lengths(x, y, sample_weights) + if self._is_graph_network and not context.in_eager_mode(): + # Additional checks to avoid users mistakenly using improper loss fns. + _check_loss_and_target_compatibility(y, self._feed_loss_fns, + feed_output_shapes) + else: + y = [] + sample_weights = [] + if self.stateful and batch_size: + # Check that for stateful networks, number of samples is a multiple + # of the static batch size. if x[0].shape[0] % batch_size != 0: raise ValueError('In a stateful network, ' 'you should only pass inputs with ' @@ -1538,19 +1766,151 @@ class Model(Network): str(x[0].shape[0]) + ' samples') return x, y, sample_weights - def _get_deduped_metrics_names(self): - out_labels = self.metrics_names + def _set_inputs(self, inputs, training=None): + """Set model's input and output specs based on the input data received. + + This is to be used for Model subclasses, which do not know at instantiation + time what their inputs look like. + + Args: + inputs: Single array, or list of arrays. The arrays could be placeholders, + Numpy arrays, or data tensors. + - if placeholders: the model is built on top of these placeholders, + and we expect Numpy data to be fed for them when calling `fit`/etc. + - if Numpy data: we create placeholders matching the shape of the Numpy + arrays. We expect Numpy data to be fed for these placeholders + when calling `fit`/etc. + - if data tensors: the model is built on top of these tensors. + We do not expect any Numpy data to be provided when calling `fit`/etc. + training: Boolean or None. Only relevant in symbolic mode. Specifies + whether to build the model's graph in inference mode (False), training + mode (True), or using the Keras learning phase (None). + """ + if context.in_eager_mode(): + self._eager_set_inputs(inputs) + else: + self._symbolic_set_inputs(inputs, training=training) + + def _eager_set_inputs(self, inputs): + """Set model's input and output specs based on the input data received. - # Rename duplicated metrics name - # (can happen with an output layer shared among multiple dataflows). - deduped_out_labels = [] - for i, label in enumerate(out_labels): - new_label = label - if out_labels.count(label) > 1: - dup_idx = out_labels[:i].count(label) - new_label += '_' + str(dup_idx + 1) - deduped_out_labels.append(new_label) - return deduped_out_labels + This is to be used for Model subclasses, which do not know at instantiation + time what their inputs look like. + + We assume the number and ndim of outputs + does not change over different calls. + + Args: + inputs: Argument `x` (input data) passed by the user upon first model use. + + Raises: + ValueError: If the model's inputs are already set. + """ + assert context.in_eager_mode() + if self.inputs: + raise ValueError('Model inputs are already set.') + # On-the-fly setting of model inputs/outputs as DeferredTensors, + # to keep track of number of inputs and outputs and their ndim. + if isinstance(inputs, (list, tuple)): + dummy_output_values = self.call( + [ops.convert_to_tensor(v, dtype=K.floatx()) for v in inputs]) + dummy_input_values = list(inputs) + else: + dummy_output_values = self.call( + ops.convert_to_tensor(inputs, dtype=K.floatx())) + dummy_input_values = [inputs] + if isinstance(dummy_output_values, (list, tuple)): + dummy_output_values = list(dummy_output_values) + else: + dummy_output_values = [dummy_output_values] + self.outputs = [ + _DeferredTensor(shape=(None for _ in v.shape), + dtype=v.dtype) for v in dummy_output_values] + self.inputs = [ + _DeferredTensor(shape=(None for _ in v.shape), + dtype=v.dtype) for v in dummy_input_values] + self.input_names = [ + 'input_%d' % (i + 1) for i in range(len(dummy_input_values))] + self.output_names = [ + 'output_%d' % (i + 1) for i in range(len(dummy_output_values))] + self.built = True + + def _symbolic_set_inputs(self, inputs, training=None): + """Set model's inputs based on the input data received from the user. + + This is to be used for Model subclasses, which do not know at instantiation + time what their inputs look like. + + Args: + inputs: Argument `x` (input data) passed by the user upon first model use. + training: Boolean or None. Only relevant in symbolic mode. Specifies + whether to build the model's graph in inference mode (False), training + mode (True), or using the Keras learning phase (None). + + Raises: + ValueError: If the model's inputs are already set. + """ + assert context.in_graph_mode() + if self.inputs: + raise ValueError('Model inputs are already set.') + + # On-the-fly setting of symbolic model inputs (either by using the tensor + # provided, or by creating a placeholder if Numpy data was provided). + self.inputs = [] + self.input_names = [] + self._feed_inputs = [] + self._feed_input_names = [] + self._feed_input_shapes = [] + if isinstance(inputs, (list, tuple)): + inputs = list(inputs) + else: + inputs = [inputs] + + for i, v in enumerate(inputs): + name = 'input_%d' % (i + 1) + self.input_names.append(name) + if isinstance(v, list): + v = np.asarray(v) + if v.ndim == 1: + v = np.expand_dims(v, 1) + if isinstance(v, (np.ndarray)): + # We fix the placeholder shape except the batch size. + # This is suboptimal, but it is the best we can do with the info + # we have. The user should call `model._set_inputs(placeholders)` + # to specify custom placeholders if the need arises. + shape = (None,) + v.shape[1:] + placeholder = K.placeholder(shape=shape, name=name) + self.inputs.append(placeholder) + self._feed_inputs.append(placeholder) + self._feed_input_names.append(name) + self._feed_input_shapes.append(shape) + else: + # Assumed tensor - TODO(fchollet) additional type check? + self.inputs.append(v) + if K.is_placeholder(v): + self._feed_inputs.append(v) + self._feed_input_names.append(name) + self._feed_input_shapes.append(K.int_shape(v)) + + # Obtain symbolic outputs by calling the model. + if len(self.inputs) == 1: + if self._expects_training_arg: + outputs = self.call(self.inputs[0], training=training) + else: + outputs = self.call(self.inputs[0]) + else: + if self._expects_training_arg: + outputs = self.call(self.inputs, training=training) + else: + outputs = self.call(self.inputs) + if isinstance(outputs, (list, tuple)): + outputs = list(outputs) + else: + outputs = [outputs] + self.outputs = outputs + self.output_names = [ + 'output_%d' % (i + 1) for i in range(len(self.outputs))] + self.built = True def fit(self, x=None, @@ -1661,6 +2021,9 @@ class Model(Network): ValueError: In case of mismatch between the provided input data and what the model expects. """ + # TODO(fchollet): this method may be creating reference cycles, which would + # lead to accumulating garbage in memory when called in a loop. Investigate. + # Backwards compatibility if batch_size is None and steps_per_epoch is None: batch_size = 32 @@ -1676,13 +2039,13 @@ class Model(Network): raise ValueError('If fitting from data tensors, ' 'you should specify the `steps_per_epoch` ' 'argument.') + # Validate user data. x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, class_weight=class_weight, - check_batch_axis=False, batch_size=batch_size) # Prepare validation data. do_validation = False @@ -1705,12 +2068,7 @@ class Model(Network): val_x, val_y, sample_weight=val_sample_weight, - check_batch_axis=False, batch_size=batch_size) - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - val_ins = val_x + val_y + val_sample_weights + [0.] - else: - val_ins = val_x + val_y + val_sample_weights elif validation_split and 0. < validation_split < 1.: do_validation = True @@ -1718,40 +2076,38 @@ class Model(Network): split_at = int(x[0].shape[0] * (1. - validation_split)) else: split_at = int(len(x[0]) * (1. - validation_split)) - x, val_x = (_slice_arrays(x, 0, split_at), _slice_arrays(x, split_at)) - y, val_y = (_slice_arrays(y, 0, split_at), _slice_arrays(y, split_at)) - sample_weights, val_sample_weights = (_slice_arrays( - sample_weights, 0, split_at), _slice_arrays(sample_weights, split_at)) - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - val_ins = val_x + val_y + val_sample_weights + [0.] - else: - val_ins = val_x + val_y + val_sample_weights - + x, val_x = (slice_arrays(x, 0, split_at), slice_arrays(x, split_at)) + y, val_y = (slice_arrays(y, 0, split_at), slice_arrays(y, split_at)) + sample_weights, val_sample_weights = (slice_arrays( + sample_weights, 0, split_at), slice_arrays(sample_weights, split_at)) elif validation_steps: + val_x = [] + val_y = [] + val_sample_weights = [] do_validation = True - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - val_ins = [0.] - - # Prepare input arrays and training function. - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + y + sample_weights + [1.] - else: - ins = x + y + sample_weights # Prepare display labels. - out_labels = self._get_deduped_metrics_names() + out_labels = self.metrics_names if context.in_eager_mode(): + if any([w is not None for w in sample_weights]): + raise ValueError('`sample_weight` and `class_weight` is not supported ' + 'when eager execution is enabled, for now.') + if do_validation: + if any([w is not None for w in val_sample_weights]): + raise ValueError('`sample_weight` and `class_weight` is not supported' + ' when eager execution is enabled, for now.') callback_metrics = copy.copy(out_labels) + [ 'val_' + n for n in out_labels ] + val_ins = val_x + val_y else: callback_metrics = copy.copy(out_labels) return training_eager.fit_loop( self, - ins, + x + y, out_labels=out_labels, batch_size=batch_size, epochs=epochs, @@ -1764,18 +2120,25 @@ class Model(Network): steps_per_epoch=steps_per_epoch, validation_steps=validation_steps) else: + # Prepare input arrays and training function. + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = x + y + sample_weights + [1] + else: + ins = x + y + sample_weights + self._make_train_function() f = self.train_function if do_validation: - if context.in_graph_mode(): - self._make_test_function() - val_f = self.test_function - else: - val_f = None + self._make_test_function() + val_f = self.test_function callback_metrics = copy.copy(out_labels) + [ 'val_' + n for n in out_labels ] + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + val_ins = val_x + val_y + val_sample_weights + [0] + else: + val_ins = val_x + val_y + val_sample_weights else: val_f = None callback_metrics = copy.copy(out_labels) @@ -1859,23 +2222,27 @@ class Model(Network): raise ValueError('If evaluating from data tensors, ' 'you should specify the `steps` ' 'argument.') + # Validate user data. x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, - check_batch_axis=False, batch_size=batch_size) - # Prepare inputs, delegate logic to `_test_loop`. - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + y + sample_weights + [0.] - else: - ins = x + y + sample_weights if context.in_eager_mode(): + if any([w is not None for w in sample_weights]): + raise ValueError('`sample_weight` and `class_weight` is not supported ' + 'when eager execution is enabled, for now.') return training_eager.test_loop( - self, ins, batch_size=batch_size, verbose=verbose, steps=steps) + self, x + y, batch_size=batch_size, verbose=verbose, steps=steps) else: + # Prepare inputs, delegate logic to `_test_loop`. + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = x + y + sample_weights + [0] + else: + ins = x + y + sample_weights + self._make_test_function() f = self.test_function return self._test_loop( @@ -1911,31 +2278,18 @@ class Model(Network): raise ValueError('If predicting from data tensors, ' 'you should specify the `steps` ' 'argument.') - # Validate user data. - x = _standardize_input_data( - x, - self._feed_input_names, - self._feed_input_shapes, - check_batch_axis=False) - if self.stateful: - if x[0].shape[0] > batch_size and x[0].shape[0] % batch_size != 0: - raise ValueError('In a stateful network, ' - 'you should only pass inputs with ' - 'a number of samples that can be ' - 'divided by the batch size. Found: ' + - str(x[0].shape[0]) + ' samples. ' - 'Batch size: ' + str(batch_size) + '.') - - # Prepare inputs, delegate logic to `_predict_loop`. - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + [0.] - else: - ins = x + x, _, _ = self._standardize_user_data(x) if context.in_eager_mode(): return training_eager.predict_loop( - self, ins, batch_size=batch_size, verbose=verbose, steps=steps) + self, x, batch_size=batch_size, verbose=verbose, steps=steps) else: + # Prepare inputs, delegate logic to `_predict_loop`. + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = x + [0] + else: + ins = x + self._make_predict_function() f = self.predict_function @@ -1977,27 +2331,32 @@ class Model(Network): and/or metrics). The attribute `model.metrics_names` will give you the display labels for the scalar outputs. + Raises: + ValueError: In case of invalid user-provided arguments. """ x, y, sample_weights = self._standardize_user_data( x, y, sample_weight=sample_weight, - class_weight=class_weight, - check_batch_axis=True) - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + y + sample_weights + [1.] - else: - ins = x + y + sample_weights + class_weight=class_weight) if context.in_eager_mode(): - return training_eager.train_on_batch(self, ins) + if any([w is not None for w in sample_weights]): + raise ValueError('`sample_weight` and `class_weight` is not supported ' + 'when eager execution is enabled, for now.') + outputs = training_eager.train_on_batch(self, x + y) + else: + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = x + y + sample_weights + [1] + else: + ins = x + y + sample_weights - if context.in_graph_mode(): self._make_train_function() outputs = self.train_function(ins) - if len(outputs) == 1: - return outputs[0] - return outputs + + if len(outputs) == 1: + return outputs[0] + return outputs def test_on_batch(self, x, y, sample_weight=None): """Test the model on a single batch of samples. @@ -2028,24 +2387,27 @@ class Model(Network): the display labels for the scalar outputs. Raises: - ValueError: in case of invalid arguments. + ValueError: In case of invalid user-provided arguments. """ x, y, sample_weights = self._standardize_user_data( - x, y, sample_weight=sample_weight, check_batch_axis=True) - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + y + sample_weights + [0.] - else: - ins = x + y + sample_weights + x, y, sample_weight=sample_weight) if context.in_eager_mode(): - return training_eager.test_on_batch(self, ins) - - if context.in_graph_mode(): + if any([w is not None for w in sample_weights]): + raise ValueError('`sample_weight` and `class_weight` is not supported ' + 'when eager execution is enabled, for now.') + outputs = training_eager.test_on_batch(self, x + y) + else: + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = x + y + sample_weights + [0] + else: + ins = x + y + sample_weights self._make_test_function() outputs = self.test_function(ins) - if len(outputs) == 1: - return outputs[0] - return outputs + + if len(outputs) == 1: + return outputs[0] + return outputs def predict_on_batch(self, x): """Returns predictions for a single batch of samples. @@ -2057,16 +2419,11 @@ class Model(Network): Numpy array(s) of predictions. """ - x = _standardize_input_data(x, self._feed_input_names, - self._feed_input_shapes) - if self.uses_learning_phase and not isinstance(K.learning_phase(), int): - ins = x + [0.] - else: - ins = x + x, _, _ = self._standardize_user_data(x) if context.in_eager_mode(): ins_batch_converted = [] - for ib in ins: + for ib in x: ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) eager_model_inputs = [] @@ -2077,6 +2434,11 @@ class Model(Network): return outs if context.in_graph_mode(): + if self.uses_learning_phase and not isinstance(K.learning_phase(), int): + ins = x + [0] + else: + ins = x + self._make_predict_function() outputs = self.predict_function(ins) if len(outputs) == 1: @@ -2190,6 +2552,10 @@ class Model(Network): ValueError: In case the generator yields data in an invalid format. """ + if not self._is_graph_network: + raise NotImplementedError( + '`fit_generator` is not yet enabled for Model subclasses') + wait_time = 0.01 # in seconds epoch = initial_epoch @@ -2228,8 +2594,8 @@ class Model(Network): ' the `keras.utils.Sequence` class.') # Prepare display labels. - out_labels = self._get_deduped_metrics_names() - callback_metrics = out_labels + ['val_' + n for n in out_labels] + out_labels = self.metrics_names + callback_metrics = out_labels + ['val_%s' % n for n in out_labels] # prepare callbacks self.history = cbks.History() @@ -2290,7 +2656,7 @@ class Model(Network): val_data = val_x + val_y + val_sample_weights if self.uses_learning_phase and not isinstance( K.learning_phase(), int): - val_data += [0.] + val_data += [0] for cbk in callbacks: cbk.validation_data = val_data @@ -2445,6 +2811,10 @@ class Model(Network): ValueError: In case the generator yields data in an invalid format. """ + if not self._is_graph_network: + raise NotImplementedError( + '`evaluate_generator` is not yet enabled for Model subclasses') + self._make_test_function() steps_done = 0 @@ -2569,6 +2939,10 @@ class Model(Network): ValueError: In case the generator yields data in an invalid format. """ + if not self._is_graph_network: + raise NotImplementedError( + '`predict_generator` is not yet enabled for Model subclasses') + self._make_predict_function() steps_done = 0 diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py index 0a115969ca614d8d50a60f8980fa49bf404cc66f..cdf189adef7ad0d7d75752a1d6289bbfc048851a 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py @@ -26,69 +26,10 @@ from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras import callbacks as cbks from tensorflow.python.keras._impl.keras import losses from tensorflow.python.keras._impl.keras import metrics as metrics_module +from tensorflow.python.keras._impl.keras.utils.generic_utils import make_batches from tensorflow.python.keras._impl.keras.utils.generic_utils import Progbar - - -def _make_batches(size, batch_size): - """Returns a list of batch indices (tuples of indices). - - Arguments: - size: Integer, total size of the data to slice into batches. - batch_size: Integer, batch size. - - Returns: - A list of tuples of array indices. - """ - num_batches = int(np.ceil(size / float(batch_size))) - return [(i * batch_size, min(size, (i + 1) * batch_size)) - for i in range(0, num_batches)] - - -def _slice_arrays(arrays, start=None, stop=None): - """Slice an array or list of arrays. - - This takes an array-like, or a list of - array-likes, and outputs: - - arrays[start:stop] if `arrays` is an array-like - - [x[start:stop] for x in arrays] if `arrays` is a list - - Can also work on list/array of indices: `_slice_arrays(x, indices)` - - Arguments: - arrays: Single array or list of arrays. - start: can be an integer index (start index) - or a list/array of indices - stop: integer (stop index); should be None if - `start` was a list. - - Returns: - A slice of the array(s). - - Raises: - ValueError: If the value of start is a list and stop is not None. - """ - if arrays is None: - return [None] - if isinstance(start, list) and stop is not None: - raise ValueError('The stop argument has to be None if the value of start is' - 'a list.') - elif isinstance(arrays, list): - if hasattr(start, '__len__'): - # hdf5 datasets only support list objects as indices - if hasattr(start, 'shape'): - start = start.tolist() - return [None if x is None else x[start] for x in arrays] - else: - return [None if x is None else x[start:stop] for x in arrays] - else: - if hasattr(start, '__len__'): - if hasattr(start, 'shape'): - start = start.tolist() - return arrays[start] - elif hasattr(start, '__getitem__'): - return arrays[start:stop] - else: - return [None] +from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays +from tensorflow.python.platform import tf_logging as logging def _get_metrics_info(metric, internal_output_shapes=None, loss_func=None): @@ -142,7 +83,7 @@ def _eager_metrics_fn(model, outputs, targets): output_metrics = model.nested_metrics[i] for nested_output_metric in output_metrics: metric_name, metric_fn = _get_metrics_info( - nested_output_metric, model._internal_output_shapes[i], + nested_output_metric, K.int_shape(model.outputs[i]), model.loss_functions[i]) if len(model.output_names) > 1: @@ -158,7 +99,7 @@ def _eager_metrics_fn(model, outputs, targets): return metric_names, metric_results -def _model_loss(model, inputs, targets): +def _model_loss(model, inputs, targets, training=False): """Calculates the loss for a given model. Arguments: @@ -166,6 +107,7 @@ def _model_loss(model, inputs, targets): inputs: The inputs of the given model. This is typically the mini batch of data that is fed to the model. targets: The predictions or targets of the given model. + training: Whether the model should be run in inference or training mode. Returns: Returns the model output, total loss and loss value calculated using the @@ -173,7 +115,16 @@ def _model_loss(model, inputs, targets): applies masking and sample weighting to the loss value. """ total_loss = 0 - outs = model(inputs) + if len(inputs) == 1: + if model._expects_training_arg: + outs = model.call(inputs[0], training=training) + else: + outs = model.call(inputs[0]) + else: + if model._expects_training_arg: + outs = model.call(inputs, training=training) + else: + outs = model.call(inputs) if not isinstance(outs, list): outs = [outs] @@ -188,6 +139,8 @@ def _model_loss(model, inputs, targets): model.output_names[i]) loss_metrics.append(K.mean(output_loss)) + # TODO(fchollet): support masking; in practice `_keras_mask` is never + # set in this context currently. mask = outs[i]._keras_mask # adapted from weighted_loss_fn if mask is not None: @@ -197,17 +150,7 @@ def _model_loss(model, inputs, targets): # to the number of unmasked samples. output_loss /= K.mean(mask) - # adapted from weighted_loss_fn - # apply sample weighting - if model.sample_weights: - # reduce score_array to same ndim as weight array - ndim = K.ndim(output_loss) - weight_ndim = K.ndim(model.sample_weights) - output_loss = K.mean(output_loss, axis=list(range(weight_ndim, ndim))) - output_loss *= model.sample_weights - output_loss /= K.mean(K.cast(K.not_equal(model.sample_weights, 0), - K.floatx())) - output_loss = K.mean(output_loss) + # TODO(fchollet): support sample weighting loss_weight = model.loss_weights_list[i] if total_loss is None: @@ -229,7 +172,7 @@ def _model_loss(model, inputs, targets): def _process_single_batch(eager_model_inputs, eager_model_outputs, model, - training=True): + training=False): """Calculate the loss and gradient for one input batch. The model weights are updated if training is set to True. @@ -246,24 +189,25 @@ def _process_single_batch(eager_model_inputs, eager_model_outputs, model, output of the model, total loss and the loss associated with each output. Raises: - ValueError: If the model loss is 0 or if the trainable weights list is - empty when the trainable parameter is set to True. + ValueError: If the model has no loss to optimize. """ K.set_learning_phase(training) with GradientTape() as tape: outs, loss, loss_metrics = _model_loss(model, eager_model_inputs, - eager_model_outputs) + eager_model_outputs, + training=training) if loss is None: raise ValueError('The model cannot be run ' 'because it has no loss to optimize.') if training: if not model._collected_trainable_weights: - raise ValueError('The list of trainable weights is empty. Make sure that ' - 'you are not setting model.trainable to False before ' - 'compiling the model.') - grads = tape.gradient(loss, model._collected_trainable_weights) - model.optimizer.apply_gradients(zip(grads, - model._collected_trainable_weights)) + logging.warning('The list of trainable weights is empty. Make sure that ' + 'you are not setting model.trainable to False before ' + 'compiling the model.') + else: + grads = tape.gradient(loss, model._collected_trainable_weights) + model.optimizer.apply_gradients(zip(grads, + model._collected_trainable_weights)) return outs, loss, loss_metrics @@ -279,7 +223,8 @@ def train_on_batch(model, ins): """ ins_batch_converted = [] for ib in ins: - ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) + if ib is not None: + ins_batch_converted.append(ops.convert_to_tensor(ib, dtype=K.floatx())) eager_model_inputs = [] eager_model_outputs = [] for i in range(len(model.inputs)): @@ -287,7 +232,7 @@ def train_on_batch(model, ins): for i in range(len(model.inputs), len(ins_batch_converted)): eager_model_outputs.append(ins_batch_converted[i]) outs, loss, _ = _process_single_batch( - eager_model_inputs, eager_model_outputs, model) + eager_model_inputs, eager_model_outputs, model, training=True) if not isinstance(outs, list): outs = [outs] _, metrics_results = _eager_metrics_fn( @@ -439,16 +384,16 @@ def fit_loop( elif shuffle: np.random.shuffle(index_array) - batches = _make_batches(num_train_samples, batch_size) + batches = make_batches(num_train_samples, batch_size) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] try: if isinstance(ins[-1], float): # Do not slice the training phase flag. - ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] else: - ins_batch = _slice_arrays(ins, batch_ids) + ins_batch = slice_arrays(ins, batch_ids) except TypeError: raise TypeError('TypeError while preparing batch. ' 'If using HDF5 input data, ' @@ -472,7 +417,8 @@ def fit_loop( outs, loss, loss_metrics = _process_single_batch(eager_model_inputs, eager_model_outputs, - model) + model, + training=True) if not isinstance(outs, list): outs = [outs] @@ -551,15 +497,15 @@ def test_loop(model, ins, batch_size=None, verbose=0, steps=None): outs = [] if verbose == 1: progbar = Progbar(target=num_samples) - batches = _make_batches(num_samples, batch_size) + batches = make_batches(num_samples, batch_size) index_array = np.arange(num_samples) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] if isinstance(ins[-1], float): # Do not slice the training phase flag. - ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] else: - ins_batch = _slice_arrays(ins, batch_ids) + ins_batch = slice_arrays(ins, batch_ids) ins_batch_converted = [] for ib in ins_batch: @@ -574,7 +520,8 @@ def test_loop(model, ins, batch_size=None, verbose=0, steps=None): eager_model_outputs.append(ins_batch_converted[i]) loss_outs, loss, loss_metrics = _model_loss(model, eager_model_inputs, - eager_model_outputs) + eager_model_outputs, + training=False) _, metrics_results = _eager_metrics_fn(model, loss_outs, eager_model_outputs) batch_outs = [] @@ -628,15 +575,15 @@ def predict_loop(model, ins, batch_size=32, verbose=0, steps=None): progbar = Progbar(target=num_samples) outs = [] - batches = _make_batches(num_samples, batch_size) + batches = make_batches(num_samples, batch_size) index_array = np.arange(num_samples) for batch_index, (batch_start, batch_end) in enumerate(batches): batch_ids = index_array[batch_start:batch_end] if ins and isinstance(ins[-1], float): # Do not slice the training phase flag. - ins_batch = _slice_arrays(ins[:-1], batch_ids) + [ins[-1]] + ins_batch = slice_arrays(ins[:-1], batch_ids) + [ins[-1]] else: - ins_batch = _slice_arrays(ins, batch_ids) + ins_batch = slice_arrays(ins, batch_ids) ins_batch_converted = [] for ib in ins_batch: @@ -646,7 +593,16 @@ def predict_loop(model, ins, batch_size=32, verbose=0, steps=None): for i in range(len(model.inputs)): eager_model_inputs.append(ins_batch_converted[i]) - batch_outs = model(eager_model_inputs) + if len(eager_model_inputs) == 1: + if model._expects_training_arg: + batch_outs = model.call(eager_model_inputs[0], training=False) + else: + batch_outs = model.call(eager_model_inputs[0]) + else: + if model._expects_training_arg: + batch_outs = model.call(eager_model_inputs, training=False) + else: + batch_outs = model.call(eager_model_inputs) if not isinstance(batch_outs, list): batch_outs = [batch_outs] diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py index 81e2f7a5145a586f6a4cc34f54033723fae6a6e9..550b86a71ddafed1f14ecd4a28ab652bb9e24154 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py @@ -309,438 +309,6 @@ class TrainingTest(test.TestCase): optimizer='rms') -class LossWeightingTest(test.TestCase): - - def test_class_weights(self): - num_classes = 5 - batch_size = 5 - epochs = 5 - weighted_class = 3 - train_samples = 3000 - test_samples = 3000 - input_dim = 5 - - model = keras.models.Sequential() - model.add(keras.layers.Dense(10, input_shape=(input_dim,))) - model.add(keras.layers.Activation('relu')) - model.add(keras.layers.Dense(num_classes)) - model.add(keras.layers.Activation('softmax')) - model.compile(loss='categorical_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001)) - - np.random.seed(1337) - (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - int_y_test = y_test.copy() - int_y_train = y_train.copy() - # convert class vectors to binary class matrices - y_train = keras.utils.to_categorical(y_train, num_classes) - y_test = keras.utils.to_categorical(y_test, num_classes) - test_ids = np.where(int_y_test == np.array(weighted_class))[0] - - class_weight = dict([(i, 1.) for i in range(num_classes)]) - class_weight[weighted_class] = 2. - - sample_weight = np.ones((y_train.shape[0])) - sample_weight[int_y_train == weighted_class] = 2. - - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs // 3, - verbose=0, - class_weight=class_weight, - validation_data=(x_train, y_train, sample_weight)) - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs // 2, - verbose=0, - class_weight=class_weight) - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs // 2, - verbose=0, - class_weight=class_weight, - validation_split=0.1) - - model.train_on_batch( - x_train[:batch_size], y_train[:batch_size], class_weight=class_weight) - ref_score = model.evaluate(x_test, y_test, verbose=0) - score = model.evaluate( - x_test[test_ids, :], y_test[test_ids, :], verbose=0) - self.assertLess(score, ref_score) - - def test_sample_weights(self): - num_classes = 5 - batch_size = 5 - epochs = 5 - weighted_class = 3 - train_samples = 3000 - test_samples = 3000 - input_dim = 5 - - model = keras.models.Sequential() - model.add(keras.layers.Dense(10, input_shape=(input_dim,))) - model.add(keras.layers.Activation('relu')) - model.add(keras.layers.Dense(num_classes)) - model.add(keras.layers.Activation('softmax')) - model.compile(loss='categorical_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001)) - - np.random.seed(43) - (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - int_y_test = y_test.copy() - int_y_train = y_train.copy() - # convert class vectors to binary class matrices - y_train = keras.utils.to_categorical(y_train, num_classes) - y_test = keras.utils.to_categorical(y_test, num_classes) - test_ids = np.where(int_y_test == np.array(weighted_class))[0] - - class_weight = dict([(i, 1.) for i in range(num_classes)]) - class_weight[weighted_class] = 2. - - sample_weight = np.ones((y_train.shape[0])) - sample_weight[int_y_train == weighted_class] = 2. - - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs // 3, - verbose=0, - sample_weight=sample_weight) - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs // 3, - verbose=0, - sample_weight=sample_weight, - validation_split=0.1) - model.train_on_batch( - x_train[:batch_size], - y_train[:batch_size], - sample_weight=sample_weight[:batch_size]) - model.test_on_batch( - x_train[:batch_size], - y_train[:batch_size], - sample_weight=sample_weight[:batch_size]) - - def test_temporal_sample_weights(self): - num_classes = 5 - weighted_class = 3 - train_samples = 1000 - test_samples = 1000 - input_dim = 5 - timesteps = 3 - - model = keras.models.Sequential() - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(num_classes), - input_shape=(timesteps, input_dim))) - model.add(keras.layers.Activation('softmax')) - - np.random.seed(1337) - (_, y_train), _ = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - int_y_train = y_train.copy() - # convert class vectors to binary class matrices - y_train = keras.utils.to_categorical(y_train, num_classes) - - class_weight = dict([(i, 1.) for i in range(num_classes)]) - class_weight[weighted_class] = 2. - - sample_weight = np.ones((y_train.shape[0])) - sample_weight[int_y_train == weighted_class] = 2. - with self.assertRaises(ValueError): - model.compile( - loss='binary_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001), - sample_weight_mode='temporal') - - def test_class_weight_invalid_use_case(self): - num_classes = 5 - train_samples = 1000 - test_samples = 1000 - input_dim = 5 - timesteps = 3 - - model = keras.models.Sequential() - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(num_classes), - input_shape=(timesteps, input_dim))) - model.add(keras.layers.Activation('softmax')) - model.compile( - loss='binary_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001)) - - (x_train, y_train), _ = testing_utils.get_test_data( - train_samples=train_samples, - test_samples=test_samples, - input_shape=(input_dim,), - num_classes=num_classes) - # convert class vectors to binary class matrices - y_train = keras.utils.to_categorical(y_train, num_classes) - class_weight = dict([(i, 1.) for i in range(num_classes)]) - - del class_weight[1] - with self.assertRaises(ValueError): - model.fit(x_train, y_train, - epochs=0, verbose=0, class_weight=class_weight) - - with self.assertRaises(ValueError): - model.compile( - loss='binary_crossentropy', - optimizer=RMSPropOptimizer(learning_rate=0.001), - sample_weight_mode=[]) - - # Build multi-output model - x = keras.Input((3,)) - y1 = keras.layers.Dense(4, name='1')(x) - y2 = keras.layers.Dense(4, name='2')(x) - model = keras.models.Model(x, [y1, y2]) - model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse') - x_np = np.random.random((10, 3)) - y_np = np.random.random((10, 4)) - w_np = np.random.random((10,)) - # This will work - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': w_np}) - # These will not - with self.assertRaises(ValueError): - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=[w_np]) - with self.assertRaises(TypeError): - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight=w_np) - with self.assertRaises(ValueError): - bad_w_np = np.random.random((11,)) - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) - with self.assertRaises(ValueError): - bad_w_np = np.random.random((10, 2)) - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) - with self.assertRaises(ValueError): - bad_w_np = np.random.random((10, 2, 2)) - model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np}) - - -class TestDynamicTrainability(test.TestCase): - - def test_trainable_warning(self): - x = np.random.random((5, 3)) - y = np.random.random((5, 2)) - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_dim=3)) - model.trainable = False - model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') - model.trainable = True - with self.assertRaises(ValueError): - model.train_on_batch(x, y) - - def test_trainable_argument(self): - x = np.random.random((5, 3)) - y = np.random.random((5, 2)) - - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_dim=3, trainable=False)) - model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') - out = model.predict(x) - with self.assertRaises(ValueError): - model.train_on_batch(x, y) - out_2 = model.predict(x) - self.assertAllClose(out, out_2) - - # test with nesting - inputs = keras.layers.Input(shape=(3,)) - output = model(inputs) - model = keras.models.Model(inputs, output) - model.compile(RMSPropOptimizer(learning_rate=0.001), 'mse') - out = model.predict(x) - with self.assertRaises(ValueError): - model.train_on_batch(x, y) - out_2 = model.predict(x) - self.assertAllClose(out, out_2) - - def test_layer_trainability_switch(self): - # with constructor argument, in Sequential - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, trainable=False, input_dim=1)) - self.assertListEqual(model.trainable_weights, []) - - # by setting the `trainable` argument, in Sequential - model = keras.models.Sequential() - layer = keras.layers.Dense(2, input_dim=1) - model.add(layer) - self.assertListEqual(model.trainable_weights, layer.trainable_weights) - layer.trainable = False - self.assertListEqual(model.trainable_weights, []) - - # with constructor argument, in Model - x = keras.layers.Input(shape=(1,)) - y = keras.layers.Dense(2, trainable=False)(x) - model = keras.models.Model(x, y) - self.assertListEqual(model.trainable_weights, []) - - # by setting the `trainable` argument, in Model - x = keras.layers.Input(shape=(1,)) - layer = keras.layers.Dense(2) - y = layer(x) - model = keras.models.Model(x, y) - self.assertListEqual(model.trainable_weights, layer.trainable_weights) - layer.trainable = False - self.assertListEqual(model.trainable_weights, []) - - def test_model_trainability_switch(self): - # a non-trainable model has no trainable weights - x = keras.layers.Input(shape=(1,)) - y = keras.layers.Dense(2)(x) - model = keras.models.Model(x, y) - model.trainable = False - self.assertListEqual(model.trainable_weights, []) - - # same for Sequential - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_dim=1)) - model.trainable = False - self.assertListEqual(model.trainable_weights, []) - - def test_nested_model_trainability(self): - - # a Sequential inside a Model - inner_model = keras.models.Sequential() - inner_model.add(keras.layers.Dense(2, input_dim=1)) - - x = keras.layers.Input(shape=(1,)) - y = inner_model(x) - outer_model = keras.models.Model(x, y) - self.assertListEqual(outer_model.trainable_weights, - inner_model.trainable_weights) - inner_model.trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - inner_model.trainable = True - inner_model.layers[-1].trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - - # a Sequential inside a Sequential - inner_model = keras.models.Sequential() - inner_model.add(keras.layers.Dense(2, input_dim=1)) - outer_model = keras.models.Sequential() - outer_model.add(inner_model) - self.assertListEqual(outer_model.trainable_weights, - inner_model.trainable_weights) - inner_model.trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - inner_model.trainable = True - inner_model.layers[-1].trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - - # a Model inside a Model - x = keras.layers.Input(shape=(1,)) - y = keras.layers.Dense(2)(x) - inner_model = keras.models.Model(x, y) - x = keras.layers.Input(shape=(1,)) - y = inner_model(x) - outer_model = keras.models.Model(x, y) - self.assertListEqual(outer_model.trainable_weights, - inner_model.trainable_weights) - inner_model.trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - inner_model.trainable = True - inner_model.layers[-1].trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - - # a Model inside a Sequential - x = keras.layers.Input(shape=(1,)) - y = keras.layers.Dense(2)(x) - inner_model = keras.models.Model(x, y) - outer_model = keras.models.Sequential() - outer_model.add(inner_model) - self.assertListEqual(outer_model.trainable_weights, - inner_model.trainable_weights) - inner_model.trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - inner_model.trainable = True - inner_model.layers[-1].trainable = False - self.assertListEqual(outer_model.trainable_weights, []) - - -class TestTrainingUtils(test.TestCase): - - def test_check_array_lengths(self): - keras.engine.training._check_array_lengths(None, None, None) - a_np = np.random.random((4, 3, 3)) - keras.engine.training._check_array_lengths(a_np, a_np, a_np) - keras.engine.training._check_array_lengths( - [a_np, a_np], [a_np, a_np], [a_np, a_np]) - keras.engine.training._check_array_lengths([None], [None], [None]) - - b_np = np.random.random((3, 4)) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths(a_np, None, None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths(a_np, a_np, None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], [None], None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], [b_np], None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], None, [b_np]) - - def test_slice_arrays(self): - input_a = np.random.random((10, 3)) - keras.engine.training._slice_arrays(None) - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) - input_a = [None, [1, 1], None, [1, 1]] - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) - input_a = [None] - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) - input_a = None - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) - - def test_fit_with_BatchNorm(self): - model = keras.models.Sequential() - model.add(keras.layers.Dense(10, input_dim=4)) - model.add(keras.layers.BatchNormalization()) - model.add(keras.layers.Activation('tanh')) - model.add(keras.layers.Dropout(0.2)) - - input_a_np = np.random.random((10, 4)) - output_b_np = np.random.random((10, 10)) - - model.compile(loss='binary_crossentropy', optimizer=RMSPropOptimizer(0.001)) - model.fit(input_a_np, output_b_np, epochs=1, batch_size=5, verbose=0) - - def test_fit_with_regularization(self): - model = keras.models.Sequential() - with self.assertRaises(ValueError): - model.add( - keras.layers.Dense(4, input_dim=3, - kernel_regularizer=keras.regularizers.l2(0.01), - activity_regularizer=keras.regularizers.l1(0.01))) - - if __name__ == '__main__': # Bazel sets these environment variables to very long paths. # Tempfile uses them to create long paths, and in turn multiprocessing diff --git a/tensorflow/python/keras/_impl/keras/engine/training_test.py b/tensorflow/python/keras/_impl/keras/engine/training_test.py index b380238e4e2bb3bccbfc5efdc0db213d86910fe5..6ca5941e9a339739a42452169615f9e0b14c79c0 100644 --- a/tensorflow/python/keras/_impl/keras/engine/training_test.py +++ b/tensorflow/python/keras/_impl/keras/engine/training_test.py @@ -26,6 +26,7 @@ import numpy as np from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.keras._impl.keras.engine.training import _weighted_masked_objective +from tensorflow.python.keras._impl.keras.utils.generic_utils import slice_arrays from tensorflow.python.platform import test try: @@ -1044,35 +1045,27 @@ class TestTrainingUtils(test.TestCase): keras.engine.training._check_array_lengths([None], [None], [None]) b_np = np.random.random((3, 4)) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths(a_np, None, None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths(a_np, a_np, None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], [None], None) with self.assertRaises(ValueError): keras.engine.training._check_array_lengths([a_np], [b_np], None) - with self.assertRaises(ValueError): - keras.engine.training._check_array_lengths([a_np], None, [b_np]) def test_slice_arrays(self): input_a = np.random.random((10, 3)) - keras.engine.training._slice_arrays(None) - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) + slice_arrays(input_a, 0) + slice_arrays(None) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) input_a = [None, [1, 1], None, [1, 1]] - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) + slice_arrays(input_a, 0) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) input_a = [None] - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) + slice_arrays(input_a, 0) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) input_a = None - keras.engine.training._slice_arrays(input_a, 0) - keras.engine.training._slice_arrays(input_a, 0, 1) - keras.engine.training._slice_arrays(input_a, stop=2) + slice_arrays(input_a, 0) + slice_arrays(input_a, 0, 1) + slice_arrays(input_a, stop=2) class TestTrainingWithDataTensors(test.TestCase): diff --git a/tensorflow/python/keras/_impl/keras/estimator.py b/tensorflow/python/keras/_impl/keras/estimator.py index db0140c2df4d20f9e18e6c1401c6c6aa197bcf1f..0bf5bd41dc915fbecbce4c3a6191e925612dbebb 100644 --- a/tensorflow/python/keras/_impl/keras/estimator.py +++ b/tensorflow/python/keras/_impl/keras/estimator.py @@ -222,18 +222,18 @@ def _save_first_checkpoint(keras_model, estimator, custom_objects, Returns: The model_fn for a keras Estimator. """ - with ops.Graph().as_default() as g, g.device(estimator._device_fn): - random_seed.set_random_seed(estimator.config.tf_random_seed) - training_util.create_global_step() - model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model, - custom_objects) - - if isinstance(model, models.Sequential): - model = model.model - # Load weights and save to checkpoint if there is no checkpoint - latest_path = saver_lib.latest_checkpoint(estimator.model_dir) - if not latest_path: - with session.Session() as sess: + # Load weights and save to checkpoint if there is no checkpoint + latest_path = saver_lib.latest_checkpoint(estimator.model_dir) + if not latest_path: + with ops.Graph().as_default(): + random_seed.set_random_seed(estimator.config.tf_random_seed) + training_util.create_global_step() + model = _clone_and_build_model(model_fn_lib.ModeKeys.TRAIN, keras_model, + custom_objects) + if isinstance(model, models.Sequential): + model = model.model + # save to checkpoint + with session.Session(config=estimator._session_config) as sess: model.set_weights(keras_weights) # Make update ops and initialize all variables. if not model.train_function: diff --git a/tensorflow/python/keras/_impl/keras/estimator_test.py b/tensorflow/python/keras/_impl/keras/estimator_test.py index 9fc48b4117e7ee2c717d5418754254aa02b82869..88dd14b856a4ee9dfbee61d6fd1bdb96af24b50c 100644 --- a/tensorflow/python/keras/_impl/keras/estimator_test.py +++ b/tensorflow/python/keras/_impl/keras/estimator_test.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import json from math import log10 import os import tempfile @@ -62,7 +63,7 @@ def simple_functional_model(): return model -def get_resource_for_simple_model(is_sequential, is_evaluate): +def get_resource_for_simple_model(is_sequential=True, is_evaluate=False): model = simple_sequential_model( ) if is_sequential else simple_functional_model() if is_sequential: @@ -352,6 +353,30 @@ class TestKerasEstimator(test_util.TensorFlowTestCase): model_dir=tempfile.mkdtemp(dir=self._base_dir), custom_objects=custom_objects) + def test_tf_config(self): + keras_model, (_, _), (_, _), _, _ = get_resource_for_simple_model() + keras_model.compile( + loss='categorical_crossentropy', + optimizer='rmsprop', + metrics=['mse', keras.metrics.categorical_accuracy]) + + tf_config = json.dumps({ + 'cluster': { + run_config_lib.TaskType.PS: ['localhost:1234'], + run_config_lib.TaskType.WORKER: ['localhost:1236'], + run_config_lib.TaskType.MASTER: ['localhost:1238'] + }, + 'task': { + 'type': run_config_lib.TaskType.MASTER, + 'index': 0 + } + }) + with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): + with self.test_session(): + keras.estimator.model_to_estimator( + keras_model=keras_model, + model_dir=tempfile.mkdtemp(dir=self._base_dir)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/initializers.py b/tensorflow/python/keras/_impl/keras/initializers.py index 338c669f97736ace721f1d7e47a79426713ccfce..300bed5e1437074d010760c427c14f68e58ac363 100644 --- a/tensorflow/python/keras/_impl/keras/initializers.py +++ b/tensorflow/python/keras/_impl/keras/initializers.py @@ -209,4 +209,5 @@ def get(identifier): elif callable(identifier): return identifier else: - raise ValueError('Could not interpret initializer identifier:', identifier) + raise ValueError('Could not interpret initializer identifier: ' + + str(identifier)) diff --git a/tensorflow/python/keras/_impl/keras/integration_test.py b/tensorflow/python/keras/_impl/keras/integration_test.py index 15c3d14727a44c9726a1c2c86f47640bcc490e70..280f7ed1b11e2026ac196eb319f7d5da8301f060 100644 --- a/tensorflow/python/keras/_impl/keras/integration_test.py +++ b/tensorflow/python/keras/_impl/keras/integration_test.py @@ -23,7 +23,6 @@ import numpy as np from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.layers import core as tf_core_layers -from tensorflow.python.layers import network as tf_network_layers from tensorflow.python.ops import nn from tensorflow.python.platform import test @@ -275,10 +274,10 @@ class KerasIntegrationTest(test.TestCase): y_train = keras.utils.to_categorical(y_train) y_test = keras.utils.to_categorical(y_test) - inputs = tf_network_layers.Input(shape=(10,)) + inputs = keras.Input(shape=(10,)) x = tf_core_layers.Dense(32, activation=nn.relu)(inputs) outputs = tf_core_layers.Dense(2, activation=nn.softmax)(x) - model = keras.models.Model(inputs, outputs) + model = keras.Model(inputs, outputs) model.summary() model.compile(loss='categorical_crossentropy', diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py index 7cac17c51a9adcf8fc62154b6633de60bab18387..c40ee109aaea7dacea72e095b1d8cea3ed2e9bf8 100644 --- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py +++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py @@ -25,7 +25,7 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py index bc43451114a0c2396b687a7734bb48391139a914..162ae6c28f1afae1dd8aaf70213b808d9ad9598f 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py @@ -60,7 +60,7 @@ class Conv1D(tf_convolutional_layers.Conv1D, Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of a single integer, specifying the length of the 1D convolution window. strides: An integer or tuple/list of a single integer, @@ -173,7 +173,7 @@ class Conv2D(tf_convolutional_layers.Conv2D, Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the width and height of the 2D convolution window. Can be a single integer to specify the same value for @@ -308,7 +308,7 @@ class Conv3D(tf_convolutional_layers.Conv3D, Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 3 integers, specifying the depth, height and width of the 3D convolution window. Can be a single integer to specify the same value for @@ -877,7 +877,7 @@ class SeparableConv2D(tf_convolutional_layers.SeparableConv2D, Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the width and height of the 2D convolution window. Can be a single integer to specify the same value for diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py index a04c3a24bfb1d7b4dc6e388ebee14147b3f89461..d95a0942452afa82e277c358be5c3b2ba061ac98 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py @@ -26,7 +26,7 @@ from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.layers.recurrent import Recurrent from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.util.tf_export import tf_export @@ -39,7 +39,7 @@ class ConvRecurrent2D(Recurrent): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of n integers, specifying the dimensions of the convolution window. strides: An integer or tuple/list of n integers, @@ -200,7 +200,7 @@ class ConvLSTM2D(ConvRecurrent2D): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of n integers, specifying the dimensions of the convolution window. strides: An integer or tuple/list of n integers, diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py index 39c9d4f0fb2751b0eef3b28f6d5b8cb0a93e22e5..c612e97a9d67f7398c78a7da1107f8e067bf9371 100644 --- a/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_test.py @@ -18,8 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import copy + import numpy as np +from tensorflow.python.eager import context +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.platform import test @@ -27,45 +31,40 @@ from tensorflow.python.platform import test class Convolution1DTest(test.TestCase): - def test_dilated_conv1d(self): - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Conv1D, - input_data=np.reshape(np.arange(4, dtype='float32'), (1, 4, 1)), - kwargs={ - 'filters': 1, - 'kernel_size': 2, - 'dilation_rate': 1, - 'padding': 'valid', - 'kernel_initializer': 'ones', - 'use_bias': False, - }, - expected_output=[[[1], [3], [5]]]) - - def test_conv_1d(self): - batch_size = 2 - steps = 8 - input_dim = 2 - kernel_size = 3 - filters = 3 + def _run_test(self, kwargs, arg, values): + num_samples = 2 + stack_size = 3 + length = 7 - for padding in ['valid', 'same']: - for strides in [1, 2]: - if padding == 'same' and strides != 1: - continue + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.Conv1D, + kwargs=test_kwargs, + input_shape=(num_samples, length, stack_size)) + + @tf_test_util.run_in_graph_and_eager_modes() + def test_conv1d(self): + kwargs = { + 'filters': 2, + 'kernel_size': 3, + } + + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [2]) + self._run_test(kwargs, 'dilation_rate', [2]) - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Conv1D, - kwargs={ - 'filters': filters, - 'kernel_size': kernel_size, - 'padding': padding, - 'strides': strides - }, - input_shape=(batch_size, steps, input_dim)) - - def test_conv_1d_regularizers(self): + kwargs = { + 'filters': 2, + 'kernel_size': 3, + 'padding': 'same', + } + self._run_test(kwargs, 'dilation_rate', [2]) + self._run_test(kwargs, 'dilation_rate', [3]) + + def test_conv1d_regularizers(self): kwargs = { 'filters': 3, 'kernel_size': 3, @@ -82,7 +81,7 @@ class Convolution1DTest(test.TestCase): layer(keras.backend.variable(np.ones((1, 5, 2)))) self.assertEqual(len(layer.losses), 3) - def test_conv_1d_constraints(self): + def test_conv1d_constraints(self): k_constraint = lambda x: x b_constraint = lambda x: x @@ -103,35 +102,44 @@ class Convolution1DTest(test.TestCase): class Conv2DTest(test.TestCase): - def test_convolution_2d(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 2 stack_size = 3 - kernel_size = (3, 2) num_row = 7 num_col = 6 - for padding in ['valid', 'same']: - for strides in [(1, 1), (2, 2)]: - if padding == 'same' and strides != (1, 1): - continue + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.SeparableConv2D, + kwargs=test_kwargs, + input_shape=(num_samples, num_row, num_col, stack_size)) + + @tf_test_util.run_in_graph_and_eager_modes() + def test_conv2d(self): + kwargs = { + 'filters': 2, + 'kernel_size': (3, 3), + } + + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [(2, 2)]) + if test.is_gpu_available(cuda_only=True): + # Only runs on GPU with CUDA, channels_first is not supported on CPU. + # TODO(b/62340061): Support channels_first on CPU. + self._run_test(kwargs, 'data_format', ['channels_first']) + self._run_test(kwargs, 'dilation_rate', [(2, 2)]) - with self.test_session(use_gpu=True): - # Only runs on GPU with CUDA, channels_first is not supported on CPU. - # TODO(b/62340061): Support channels_first on CPU. - if test.is_gpu_available(cuda_only=True): - testing_utils.layer_test( - keras.layers.Conv2D, - kwargs={ - 'filters': filters, - 'kernel_size': kernel_size, - 'padding': padding, - 'strides': strides, - 'data_format': 'channels_first' - }, - input_shape=(num_samples, stack_size, num_row, num_col)) - - def test_convolution_2d_regularizers(self): + kwargs = { + 'filters': 2, + 'kernel_size': 3, + 'padding': 'same', + } + self._run_test(kwargs, 'dilation_rate', [2]) + + def test_conv2d_regularizers(self): kwargs = { 'filters': 3, 'kernel_size': 3, @@ -148,7 +156,7 @@ class Conv2DTest(test.TestCase): layer(keras.backend.variable(np.ones((1, 5, 5, 2)))) self.assertEqual(len(layer.losses), 3) - def test_convolution_2d_constraints(self): + def test_conv2d_constraints(self): k_constraint = lambda x: x b_constraint = lambda x: x @@ -166,51 +174,35 @@ class Conv2DTest(test.TestCase): self.assertEqual(layer.kernel.constraint, k_constraint) self.assertEqual(layer.bias.constraint, b_constraint) - def test_dilated_conv_2d(self): - num_samples = 2 - filters = 2 - stack_size = 3 - kernel_size = (3, 2) - num_row = 7 - num_col = 6 - - # Test dilation - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Conv2D, - kwargs={ - 'filters': filters, - 'kernel_size': kernel_size, - 'dilation_rate': (2, 2) - }, - input_shape=(num_samples, num_row, num_col, stack_size)) - class Conv2DTransposeTest(test.TestCase): - def test_conv2d_transpose(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 2 stack_size = 3 - num_row = 5 + num_row = 7 num_col = 6 - for padding in ['valid', 'same']: - for strides in [(1, 1), (2, 2)]: - if padding == 'same' and strides != (1, 1): - continue + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.Conv2DTranspose, + kwargs=test_kwargs, + input_shape=(num_samples, num_row, num_col, stack_size)) - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Conv2DTranspose, - kwargs={ - 'filters': filters, - 'kernel_size': 3, - 'padding': padding, - 'strides': strides, - 'data_format': 'channels_last' - }, - input_shape=(num_samples, num_row, num_col, stack_size)) + @tf_test_util.run_in_graph_and_eager_modes() + def test_conv2dtranspose(self): + kwargs = { + 'filters': 2, + 'kernel_size': (3, 3), + } + + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [(2, 2)]) + if test.is_gpu_available(cuda_only=True): + self._run_test(kwargs, 'data_format', ['channels_first']) def test_conv2dtranspose_regularizers(self): kwargs = { @@ -250,30 +242,33 @@ class Conv2DTransposeTest(test.TestCase): class Conv3DTransposeTest(test.TestCase): - def test_conv3d_transpose(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 2 stack_size = 3 - num_row = 5 + num_row = 7 num_col = 6 - depth = 4 + depth = 5 - for padding in ['valid', 'same']: - for strides in [(1, 1, 1), (2, 2, 2)]: - if padding == 'same' and strides != (1, 1, 1): - continue + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.Conv3DTranspose, + kwargs=test_kwargs, + input_shape=(num_samples, depth, num_row, num_col, stack_size)) - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Conv3DTranspose, - kwargs={ - 'filters': filters, - 'kernel_size': 3, - 'padding': padding, - 'strides': strides, - 'data_format': 'channels_last' - }, - input_shape=(num_samples, depth, num_row, num_col, stack_size)) + @tf_test_util.run_in_graph_and_eager_modes() + def test_conv3dtranspose(self): + kwargs = { + 'filters': 2, + 'kernel_size': (3, 3, 3), + } + + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [(2, 2, 2)]) + if test.is_gpu_available(cuda_only=True): + self._run_test(kwargs, 'data_format', ['channels_first']) def test_conv3dtranspose_regularizers(self): kwargs = { @@ -313,29 +308,38 @@ class Conv3DTransposeTest(test.TestCase): class SeparableConv1DTest(test.TestCase): - def test_separable_conv_1d(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 6 stack_size = 3 length = 7 - strides = 1 - for padding in ['valid', 'same']: - for multiplier in [1, 2]: - if padding == 'same' and strides != 1: - continue + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.SeparableConv1D, + kwargs=test_kwargs, + input_shape=(num_samples, length, stack_size)) - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.SeparableConv1D, - kwargs={ - 'filters': filters, - 'kernel_size': 3, - 'padding': padding, - 'strides': strides, - 'depth_multiplier': multiplier - }, - input_shape=(num_samples, length, stack_size)) + @tf_test_util.run_in_graph_and_eager_modes() + def test_separable_conv1d(self): + kwargs = { + 'filters': 2, + 'kernel_size': 3, + } + + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [2]) + self._run_test(kwargs, 'dilation_rate', [2]) + self._run_test(kwargs, 'depth_multiplier', [2]) + + kwargs = { + 'filters': 2, + 'kernel_size': 3, + 'padding': 'same', + } + self._run_test(kwargs, 'dilation_rate', [2]) def test_separable_conv1d_regularizers(self): kwargs = { @@ -379,30 +383,41 @@ class SeparableConv1DTest(test.TestCase): class SeparableConv2DTest(test.TestCase): - def test_separable_conv_2d(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 6 stack_size = 3 num_row = 7 num_col = 6 - for padding in ['valid', 'same']: - for strides in [(1, 1), (2, 2)]: - for multiplier in [1, 2]: - if padding == 'same' and strides != (1, 1): - continue + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.SeparableConv2D, + kwargs=test_kwargs, + input_shape=(num_samples, num_row, num_col, stack_size)) - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.SeparableConv2D, - kwargs={ - 'filters': filters, - 'kernel_size': (3, 3), - 'padding': padding, - 'strides': strides, - 'depth_multiplier': multiplier - }, - input_shape=(num_samples, num_row, num_col, stack_size)) + @tf_test_util.run_in_graph_and_eager_modes() + def test_separable_conv2d(self): + kwargs = { + 'filters': 2, + 'kernel_size': 3, + } + + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [2]) + if test.is_gpu_available(cuda_only=True): + self._run_test(kwargs, 'data_format', ['channels_first']) + self._run_test(kwargs, 'dilation_rate', [2]) + self._run_test(kwargs, 'depth_multiplier', [2]) + + kwargs = { + 'filters': 2, + 'kernel_size': 3, + 'padding': 'same', + } + self._run_test(kwargs, 'dilation_rate', [2]) def test_separable_conv2d_regularizers(self): kwargs = { @@ -446,33 +461,36 @@ class SeparableConv2DTest(test.TestCase): class Conv3DTest(test.TestCase): - def test_convolution_3d(self): + def _run_test(self, kwargs, arg, values): num_samples = 2 - filters = 2 stack_size = 3 + num_row = 7 + num_col = 6 + depth = 5 - input_len_dim1 = 9 - input_len_dim2 = 8 - input_len_dim3 = 8 + test_kwargs = copy.copy(kwargs) + for value in values: + test_kwargs[arg] = value + with self.test_session(use_gpu=True): + testing_utils.layer_test( + keras.layers.Conv3D, + kwargs=test_kwargs, + input_shape=(num_samples, depth, num_row, num_col, stack_size)) - for padding in ['valid', 'same']: - for strides in [(1, 1, 1), (2, 2, 2)]: - if padding == 'same' and strides != (1, 1, 1): - continue + @tf_test_util.run_in_graph_and_eager_modes() + def test_conv3d(self): + kwargs = { + 'filters': 2, + 'kernel_size': (3, 3, 3), + } - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.Convolution3D, - kwargs={ - 'filters': filters, - 'kernel_size': 3, - 'padding': padding, - 'strides': strides - }, - input_shape=(num_samples, input_len_dim1, input_len_dim2, - input_len_dim3, stack_size)) - - def test_convolution_3d_regularizers(self): + self._run_test(kwargs, 'padding', ['valid', 'same']) + self._run_test(kwargs, 'strides', [(2, 2, 2)]) + self._run_test(kwargs, 'dilation_rate', [(2, 2, 2)]) + if test.is_gpu_available(cuda_only=True): + self._run_test(kwargs, 'data_format', ['channels_first']) + + def test_conv3d_regularizers(self): kwargs = { 'filters': 3, 'kernel_size': 3, @@ -490,7 +508,7 @@ class Conv3DTest(test.TestCase): layer(keras.backend.variable(np.ones((1, 5, 5, 5, 2)))) self.assertEqual(len(layer.losses), 3) - def test_convolution_3d_constraints(self): + def test_conv3d_constraints(self): k_constraint = lambda x: x b_constraint = lambda x: x @@ -511,6 +529,7 @@ class Conv3DTest(test.TestCase): class ZeroPaddingTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes() def test_zero_padding_1d(self): num_samples = 2 input_dim = 2 @@ -534,7 +553,10 @@ class ZeroPaddingTest(test.TestCase): layer = keras.layers.ZeroPadding1D(padding=2) layer.build(shape) output = layer(keras.backend.variable(inputs)) - np_output = keras.backend.eval(output) + if context.in_eager_mode(): + np_output = output.numpy() + else: + np_output = keras.backend.eval(output) for offset in [0, 1, -1, -2]: np.testing.assert_allclose(np_output[:, offset, :], 0.) np.testing.assert_allclose(np_output[:, 2:-2, :], 1.) @@ -542,7 +564,10 @@ class ZeroPaddingTest(test.TestCase): layer = keras.layers.ZeroPadding1D(padding=(1, 2)) layer.build(shape) output = layer(keras.backend.variable(inputs)) - np_output = keras.backend.eval(output) + if context.in_eager_mode(): + np_output = output.numpy() + else: + np_output = keras.backend.eval(output) for left_offset in [0]: np.testing.assert_allclose(np_output[:, left_offset, :], 0.) for right_offset in [-1, -2]: @@ -556,6 +581,7 @@ class ZeroPaddingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.ZeroPadding1D(padding=None) + @tf_test_util.run_in_graph_and_eager_modes() def test_zero_padding_2d(self): num_samples = 2 stack_size = 2 @@ -584,7 +610,10 @@ class ZeroPaddingTest(test.TestCase): padding=(2, 2), data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - np_output = keras.backend.eval(output) + if context.in_eager_mode(): + np_output = output.numpy() + else: + np_output = keras.backend.eval(output) if data_format == 'channels_last': for offset in [0, 1, -1, -2]: np.testing.assert_allclose(np_output[:, offset, :, :], 0.) @@ -600,7 +629,10 @@ class ZeroPaddingTest(test.TestCase): padding=((1, 2), (3, 4)), data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - np_output = keras.backend.eval(output) + if context.in_eager_mode(): + np_output = output.numpy() + else: + np_output = keras.backend.eval(output) if data_format == 'channels_last': for top_offset in [0]: np.testing.assert_allclose(np_output[:, top_offset, :, :], 0.) @@ -628,6 +660,7 @@ class ZeroPaddingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.ZeroPadding2D(padding=None) + @tf_test_util.run_in_graph_and_eager_modes() def test_zero_padding_3d(self): num_samples = 2 stack_size = 2 @@ -650,7 +683,10 @@ class ZeroPaddingTest(test.TestCase): layer = keras.layers.ZeroPadding3D(padding=(2, 2, 2)) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - np_output = keras.backend.eval(output) + if context.in_eager_mode(): + np_output = output.numpy() + else: + np_output = keras.backend.eval(output) for offset in [0, 1, -1, -2]: np.testing.assert_allclose(np_output[:, offset, :, :, :], 0.) np.testing.assert_allclose(np_output[:, :, offset, :, :], 0.) @@ -666,11 +702,13 @@ class ZeroPaddingTest(test.TestCase): class UpSamplingTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes() def test_upsampling_1d(self): with self.test_session(use_gpu=True): testing_utils.layer_test( keras.layers.UpSampling1D, kwargs={'size': 2}, input_shape=(3, 5, 4)) + @tf_test_util.run_in_graph_and_eager_modes() def test_upsampling_2d(self): num_samples = 2 stack_size = 2 @@ -699,7 +737,10 @@ class UpSamplingTest(test.TestCase): size=(length_row, length_col), data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - np_output = keras.backend.eval(output) + if context.in_eager_mode(): + np_output = output.numpy() + else: + np_output = keras.backend.eval(output) if data_format == 'channels_first': assert np_output.shape[2] == length_row * input_num_row assert np_output.shape[3] == length_col * input_num_col @@ -717,6 +758,7 @@ class UpSamplingTest(test.TestCase): np.testing.assert_allclose(np_output, expected_out) + @tf_test_util.run_in_graph_and_eager_modes() def test_upsampling_3d(self): num_samples = 2 stack_size = 2 @@ -748,7 +790,10 @@ class UpSamplingTest(test.TestCase): data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - np_output = keras.backend.eval(output) + if context.in_eager_mode(): + np_output = output.numpy() + else: + np_output = keras.backend.eval(output) if data_format == 'channels_first': assert np_output.shape[2] == length_dim1 * input_len_dim1 assert np_output.shape[3] == length_dim2 * input_len_dim2 @@ -773,6 +818,7 @@ class UpSamplingTest(test.TestCase): class CroppingTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes() def test_cropping_1d(self): num_samples = 2 time_length = 4 @@ -791,6 +837,7 @@ class CroppingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.Cropping1D(cropping=None) + @tf_test_util.run_in_graph_and_eager_modes() def test_cropping_2d(self): num_samples = 2 stack_size = 2 @@ -818,7 +865,10 @@ class CroppingTest(test.TestCase): cropping=cropping, data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - np_output = keras.backend.eval(output) + if context.in_eager_mode(): + np_output = output.numpy() + else: + np_output = keras.backend.eval(output) # compare with numpy if data_format == 'channels_first': expected_out = inputs[:, :, cropping[0][0]:-cropping[0][1], cropping[ @@ -842,7 +892,10 @@ class CroppingTest(test.TestCase): cropping=cropping, data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - np_output = keras.backend.eval(output) + if context.in_eager_mode(): + np_output = output.numpy() + else: + np_output = keras.backend.eval(output) # compare with input np.testing.assert_allclose(np_output, inputs) @@ -852,6 +905,7 @@ class CroppingTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.Cropping2D(cropping=None) + @tf_test_util.run_in_graph_and_eager_modes() def test_cropping_3d(self): num_samples = 2 stack_size = 2 @@ -883,7 +937,10 @@ class CroppingTest(test.TestCase): cropping=cropping, data_format=data_format) layer.build(inputs.shape) output = layer(keras.backend.variable(inputs)) - np_output = keras.backend.eval(output) + if context.in_eager_mode(): + np_output = output.numpy() + else: + np_output = keras.backend.eval(output) # compare with numpy if data_format == 'channels_first': expected_out = inputs[:, :, diff --git a/tensorflow/python/keras/_impl/keras/layers/core_test.py b/tensorflow/python/keras/_impl/keras/layers/core_test.py index bdb99c91c289cf808fec7b891376dbfcf5504aca..2ca816adbdcecaf371776d99f3da60d0d8790832 100644 --- a/tensorflow/python/keras/_impl/keras/layers/core_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/core_test.py @@ -20,11 +20,9 @@ from __future__ import print_function import numpy as np -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils -from tensorflow.python.ops import init_ops from tensorflow.python.platform import test @@ -52,146 +50,134 @@ class CoreLayersTest(test.TestCase): dropout = keras.layers.Dropout(0.5) self.assertEqual(True, dropout.supports_masking) - with self.test_session(): - testing_utils.layer_test( - keras.layers.SpatialDropout1D, - kwargs={'rate': 0.5}, - input_shape=(2, 3, 4)) - - with self.test_session(): - testing_utils.layer_test( - keras.layers.SpatialDropout2D, - kwargs={'rate': 0.5}, - input_shape=(2, 3, 4, 5)) - - with self.test_session(): - testing_utils.layer_test( - keras.layers.SpatialDropout2D, - kwargs={'rate': 0.5, 'data_format': 'channels_first'}, - input_shape=(2, 3, 4, 5)) - - with self.test_session(): - testing_utils.layer_test( - keras.layers.SpatialDropout3D, - kwargs={'rate': 0.5}, - input_shape=(2, 3, 4, 4, 5)) - - with self.test_session(): - testing_utils.layer_test( - keras.layers.SpatialDropout3D, - kwargs={'rate': 0.5, 'data_format': 'channels_first'}, - input_shape=(2, 3, 4, 4, 5)) - + @tf_test_util.run_in_graph_and_eager_modes() + def test_spatial_dropout(self): + testing_utils.layer_test( + keras.layers.SpatialDropout1D, + kwargs={'rate': 0.5}, + input_shape=(2, 3, 4)) + + testing_utils.layer_test( + keras.layers.SpatialDropout2D, + kwargs={'rate': 0.5}, + input_shape=(2, 3, 4, 5)) + + testing_utils.layer_test( + keras.layers.SpatialDropout2D, + kwargs={'rate': 0.5, 'data_format': 'channels_first'}, + input_shape=(2, 3, 4, 5)) + + testing_utils.layer_test( + keras.layers.SpatialDropout3D, + kwargs={'rate': 0.5}, + input_shape=(2, 3, 4, 4, 5)) + + testing_utils.layer_test( + keras.layers.SpatialDropout3D, + kwargs={'rate': 0.5, 'data_format': 'channels_first'}, + input_shape=(2, 3, 4, 4, 5)) + + @tf_test_util.run_in_graph_and_eager_modes() def test_activation(self): # with string argument - with self.test_session(): - testing_utils.layer_test( - keras.layers.Activation, - kwargs={'activation': 'relu'}, - input_shape=(3, 2)) + testing_utils.layer_test( + keras.layers.Activation, + kwargs={'activation': 'relu'}, + input_shape=(3, 2)) # with function argument - with self.test_session(): - testing_utils.layer_test( - keras.layers.Activation, - kwargs={'activation': keras.backend.relu}, - input_shape=(3, 2)) + testing_utils.layer_test( + keras.layers.Activation, + kwargs={'activation': keras.backend.relu}, + input_shape=(3, 2)) + @tf_test_util.run_in_graph_and_eager_modes() def test_reshape(self): - with self.test_session(): - testing_utils.layer_test( - keras.layers.Reshape, - kwargs={'target_shape': (8, 1)}, - input_shape=(3, 2, 4)) - - with self.test_session(): - testing_utils.layer_test( - keras.layers.Reshape, - kwargs={'target_shape': (-1, 1)}, - input_shape=(3, 2, 4)) - - with self.test_session(): - testing_utils.layer_test( - keras.layers.Reshape, - kwargs={'target_shape': (1, -1)}, - input_shape=(3, 2, 4)) - - with self.test_session(): - testing_utils.layer_test( - keras.layers.Reshape, - kwargs={'target_shape': (-1, 1)}, - input_shape=(None, None, 2)) - + testing_utils.layer_test( + keras.layers.Reshape, + kwargs={'target_shape': (8, 1)}, + input_shape=(3, 2, 4)) + + testing_utils.layer_test( + keras.layers.Reshape, + kwargs={'target_shape': (-1, 1)}, + input_shape=(3, 2, 4)) + + testing_utils.layer_test( + keras.layers.Reshape, + kwargs={'target_shape': (1, -1)}, + input_shape=(3, 2, 4)) + + testing_utils.layer_test( + keras.layers.Reshape, + kwargs={'target_shape': (-1, 1)}, + input_shape=(None, None, 2)) + + @tf_test_util.run_in_graph_and_eager_modes() def test_permute(self): - with self.test_session(): - testing_utils.layer_test( - keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4)) + testing_utils.layer_test( + keras.layers.Permute, kwargs={'dims': (2, 1)}, input_shape=(3, 2, 4)) + @tf_test_util.run_in_graph_and_eager_modes() def test_flatten(self): - with self.test_session(): - testing_utils.layer_test( - keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4)) + testing_utils.layer_test( + keras.layers.Flatten, kwargs={}, input_shape=(3, 2, 4)) + @tf_test_util.run_in_graph_and_eager_modes() def test_repeat_vector(self): - with self.test_session(): - testing_utils.layer_test( - keras.layers.RepeatVector, kwargs={'n': 3}, input_shape=(3, 2)) + testing_utils.layer_test( + keras.layers.RepeatVector, kwargs={'n': 3}, input_shape=(3, 2)) + @tf_test_util.run_in_graph_and_eager_modes() def test_lambda(self): - with self.test_session(): - testing_utils.layer_test( - keras.layers.Lambda, - kwargs={'function': lambda x: x + 1}, - input_shape=(3, 2)) - - with self.test_session(): - testing_utils.layer_test( - keras.layers.Lambda, - kwargs={ - 'function': lambda x, a, b: x * a + b, - 'arguments': { - 'a': 0.6, - 'b': 0.4 - } - }, - input_shape=(3, 2)) - - with self.test_session(): - # test serialization with function - def f(x): - return x + 1 - - ld = keras.layers.Lambda(f) - config = ld.get_config() - ld = keras.layers.deserialize({ - 'class_name': 'Lambda', - 'config': config - }) - - # test with lambda - ld = keras.layers.Lambda( - lambda x: keras.backend.concatenate([keras.backend.square(x), x])) - config = ld.get_config() - ld = keras.layers.Lambda.from_config(config) - + testing_utils.layer_test( + keras.layers.Lambda, + kwargs={'function': lambda x: x + 1}, + input_shape=(3, 2)) + + testing_utils.layer_test( + keras.layers.Lambda, + kwargs={ + 'function': lambda x, a, b: x * a + b, + 'arguments': { + 'a': 0.6, + 'b': 0.4 + } + }, + input_shape=(3, 2)) + + # test serialization with function + def f(x): + return x + 1 + + ld = keras.layers.Lambda(f) + config = ld.get_config() + ld = keras.layers.deserialize({ + 'class_name': 'Lambda', + 'config': config + }) + + # test with lambda + ld = keras.layers.Lambda( + lambda x: keras.backend.concatenate([keras.backend.square(x), x])) + config = ld.get_config() + ld = keras.layers.Lambda.from_config(config) + + @tf_test_util.run_in_graph_and_eager_modes() def test_dense(self): - with self.test_session(): - testing_utils.layer_test( - keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2)) + testing_utils.layer_test( + keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 2)) - with self.test_session(): - testing_utils.layer_test( - keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 2)) + testing_utils.layer_test( + keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 2)) - with self.test_session(): - testing_utils.layer_test( - keras.layers.Dense, kwargs={'units': 3}, input_shape=(None, None, 2)) + testing_utils.layer_test( + keras.layers.Dense, kwargs={'units': 3}, input_shape=(None, None, 2)) - with self.test_session(): - testing_utils.layer_test( - keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 5, 2)) + testing_utils.layer_test( + keras.layers.Dense, kwargs={'units': 3}, input_shape=(3, 4, 5, 2)) - # Test regularization + def test_dense_regularization(self): with self.test_session(): layer = keras.layers.Dense( 3, @@ -202,7 +188,7 @@ class CoreLayersTest(test.TestCase): layer(keras.backend.variable(np.ones((2, 4)))) self.assertEqual(3, len(layer.losses)) - # Test constraints + def test_dense_constraints(self): with self.test_session(): k_constraint = keras.constraints.max_norm(0.01) b_constraint = keras.constraints.max_norm(0.01) @@ -212,12 +198,6 @@ class CoreLayersTest(test.TestCase): self.assertEqual(layer.kernel.constraint, k_constraint) self.assertEqual(layer.bias.constraint, b_constraint) - def test_eager_dense(self): - with context.eager_mode(): - l = keras.layers.Dense(units=3, - kernel_initializer=init_ops.zeros_initializer()) - self.assertAllEqual(l(constant_op.constant([[1.0]])), [[0., 0., 0.]]) - def test_activity_regularization(self): with self.test_session(): layer = keras.layers.ActivityRegularization(l1=0.1) diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py index ca92899a455cd28a756e9efff63655d7c43c9f45..006ecd3135be25d43133daed1603734ecd1be955 100644 --- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py +++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py @@ -23,7 +23,7 @@ from tensorflow.python.keras._impl.keras import constraints from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py b/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py index 1712111b877cf1fee4353c5542f33a973a26de95..26fd1f1c114587c2f1b3e0155f1259dd5f0dcf60 100644 --- a/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/embeddings_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.platform import test @@ -25,47 +26,44 @@ from tensorflow.python.platform import test class EmbeddingTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes() def test_embedding(self): - with self.test_session(): - testing_utils.layer_test( - keras.layers.Embedding, - kwargs={'output_dim': 4, - 'input_dim': 10, - 'input_length': 2}, - input_shape=(3, 2), - input_dtype='int32', - expected_output_dtype='float32') + testing_utils.layer_test( + keras.layers.Embedding, + kwargs={'output_dim': 4, + 'input_dim': 10, + 'input_length': 2}, + input_shape=(3, 2), + input_dtype='int32', + expected_output_dtype='float32') - with self.test_session(): - testing_utils.layer_test( - keras.layers.Embedding, - kwargs={'output_dim': 4, - 'input_dim': 10, - 'mask_zero': True}, - input_shape=(3, 2), - input_dtype='int32', - expected_output_dtype='float32') + testing_utils.layer_test( + keras.layers.Embedding, + kwargs={'output_dim': 4, + 'input_dim': 10, + 'mask_zero': True}, + input_shape=(3, 2), + input_dtype='int32', + expected_output_dtype='float32') - with self.test_session(): - testing_utils.layer_test( - keras.layers.Embedding, - kwargs={'output_dim': 4, - 'input_dim': 10, - 'mask_zero': True}, - input_shape=(3, 4, 2), - input_dtype='int32', - expected_output_dtype='float32') + testing_utils.layer_test( + keras.layers.Embedding, + kwargs={'output_dim': 4, + 'input_dim': 10, + 'mask_zero': True}, + input_shape=(3, 4, 2), + input_dtype='int32', + expected_output_dtype='float32') - with self.test_session(): - testing_utils.layer_test( - keras.layers.Embedding, - kwargs={'output_dim': 4, - 'input_dim': 10, - 'mask_zero': True, - 'input_length': (None, 2)}, - input_shape=(3, 4, 2), - input_dtype='int32', - expected_output_dtype='float32') + testing_utils.layer_test( + keras.layers.Embedding, + kwargs={'output_dim': 4, + 'input_dim': 10, + 'mask_zero': True, + 'input_length': (None, 2)}, + input_shape=(3, 4, 2), + input_dtype='int32', + expected_output_dtype='float32') if __name__ == '__main__': diff --git a/tensorflow/python/keras/_impl/keras/layers/gru_test.py b/tensorflow/python/keras/_impl/keras/layers/gru_test.py index c57fbac41cc43995ef3249414ed03928e7ffd044..48e7e14f5ab73b534ab0d1c765ad2572b2930b2b 100644 --- a/tensorflow/python/keras/_impl/keras/layers/gru_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/gru_test.py @@ -20,64 +20,66 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.platform import test +from tensorflow.python.training.rmsprop import RMSPropOptimizer class GRULayerTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes() def test_return_sequences_GRU(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - testing_utils.layer_test( - keras.layers.GRU, - kwargs={'units': units, - 'return_sequences': True}, - input_shape=(num_samples, timesteps, embedding_dim)) + testing_utils.layer_test( + keras.layers.GRU, + kwargs={'units': units, + 'return_sequences': True}, + input_shape=(num_samples, timesteps, embedding_dim)) + @tf_test_util.run_in_graph_and_eager_modes() def test_dynamic_behavior_GRU(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - layer = keras.layers.GRU(units, input_shape=(None, embedding_dim)) - model = keras.models.Sequential() - model.add(layer) - model.compile('sgd', 'mse') - x = np.random.random((num_samples, timesteps, embedding_dim)) - y = np.random.random((num_samples, units)) - model.train_on_batch(x, y) - + layer = keras.layers.GRU(units, input_shape=(None, embedding_dim)) + model = keras.models.Sequential() + model.add(layer) + model.compile(RMSPropOptimizer(0.01), 'mse') + x = np.random.random((num_samples, timesteps, embedding_dim)) + y = np.random.random((num_samples, units)) + model.train_on_batch(x, y) + + @tf_test_util.run_in_graph_and_eager_modes() def test_dropout_GRU(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - testing_utils.layer_test( - keras.layers.GRU, - kwargs={'units': units, - 'dropout': 0.1, - 'recurrent_dropout': 0.1}, - input_shape=(num_samples, timesteps, embedding_dim)) - + testing_utils.layer_test( + keras.layers.GRU, + kwargs={'units': units, + 'dropout': 0.1, + 'recurrent_dropout': 0.1}, + input_shape=(num_samples, timesteps, embedding_dim)) + + @tf_test_util.run_in_graph_and_eager_modes() def test_implementation_mode_GRU(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - for mode in [0, 1, 2]: - testing_utils.layer_test( - keras.layers.GRU, - kwargs={'units': units, - 'implementation': mode}, - input_shape=(num_samples, timesteps, embedding_dim)) + for mode in [0, 1, 2]: + testing_utils.layer_test( + keras.layers.GRU, + kwargs={'units': units, + 'implementation': mode}, + input_shape=(num_samples, timesteps, embedding_dim)) def test_statefulness_GRU(self): num_samples = 2 diff --git a/tensorflow/python/keras/_impl/keras/layers/local.py b/tensorflow/python/keras/_impl/keras/layers/local.py index 798ac236a30a438107caed939f7650f51b62ef42..13d96e939220c11a4090cf535e3efa4365fe8b62 100644 --- a/tensorflow/python/keras/_impl/keras/layers/local.py +++ b/tensorflow/python/keras/_impl/keras/layers/local.py @@ -25,7 +25,7 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.utils import conv_utils from tensorflow.python.util.tf_export import tf_export @@ -53,7 +53,7 @@ class LocallyConnected1D(Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of a single integer, specifying the length of the 1D convolution window. strides: An integer or tuple/list of a single integer, @@ -222,7 +222,7 @@ class LocallyConnected2D(Layer): Arguments: filters: Integer, the dimensionality of the output space - (i.e. the number output of filters in the convolution). + (i.e. the number of output filters in the convolution). kernel_size: An integer or tuple/list of 2 integers, specifying the width and height of the 2D convolution window. Can be a single integer to specify the same value for diff --git a/tensorflow/python/keras/_impl/keras/layers/local_test.py b/tensorflow/python/keras/_impl/keras/layers/local_test.py index a815a0fadc8215c00f3db4749e323f96e44b66f3..93741d24b9a74cf9e8a83069f7c4235b1f489818 100644 --- a/tensorflow/python/keras/_impl/keras/layers/local_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/local_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.platform import test @@ -27,6 +28,7 @@ from tensorflow.python.platform import test class LocallyConnectedLayersTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes() def test_locallyconnected_1d(self): num_samples = 2 num_steps = 8 @@ -39,16 +41,15 @@ class LocallyConnectedLayersTest(test.TestCase): if padding == 'same' and strides != 1: continue - with self.test_session(): - testing_utils.layer_test( - keras.layers.LocallyConnected1D, - kwargs={ - 'filters': filters, - 'kernel_size': filter_length, - 'padding': padding, - 'strides': strides - }, - input_shape=(num_samples, num_steps, input_dim)) + testing_utils.layer_test( + keras.layers.LocallyConnected1D, + kwargs={ + 'filters': filters, + 'kernel_size': filter_length, + 'padding': padding, + 'strides': strides + }, + input_shape=(num_samples, num_steps, input_dim)) def test_locallyconnected_1d_regularization(self): num_samples = 2 @@ -86,6 +87,7 @@ class LocallyConnectedLayersTest(test.TestCase): self.assertEqual(layer.kernel.constraint, k_constraint) self.assertEqual(layer.bias.constraint, b_constraint) + @tf_test_util.run_in_graph_and_eager_modes() def test_locallyconnected_2d(self): num_samples = 8 filters = 3 @@ -98,20 +100,18 @@ class LocallyConnectedLayersTest(test.TestCase): if padding == 'same' and strides != (1, 1): continue - with self.test_session(): - testing_utils.layer_test( - keras.layers.LocallyConnected2D, - kwargs={ - 'filters': filters, - 'kernel_size': 3, - 'padding': padding, - 'kernel_regularizer': 'l2', - 'bias_regularizer': 'l2', - 'activity_regularizer': 'l2', - 'strides': strides, - 'data_format': 'channels_last' - }, - input_shape=(num_samples, num_row, num_col, stack_size)) + testing_utils.layer_test( + keras.layers.LocallyConnected2D, + kwargs={ + 'filters': filters, + 'kernel_size': 3, + 'padding': padding, + 'kernel_regularizer': 'l2', + 'bias_regularizer': 'l2', + 'strides': strides, + 'data_format': 'channels_last' + }, + input_shape=(num_samples, num_row, num_col, stack_size)) def test_locallyconnected_2d_channels_first(self): num_samples = 8 diff --git a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py index deb1d7c0c685e51ed756cbcdd5aec81ee60b5f96..11a5e0aeaacfa7520361ae41ac3d40607e8a9050 100644 --- a/tensorflow/python/keras/_impl/keras/layers/lstm_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/lstm_test.py @@ -20,28 +20,29 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.platform import test +from tensorflow.python.training.rmsprop import RMSPropOptimizer class LSTMLayerTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes() def test_return_sequences_LSTM(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - testing_utils.layer_test( - keras.layers.LSTM, - kwargs={'units': units, - 'return_sequences': True}, - input_shape=(num_samples, timesteps, embedding_dim)) + testing_utils.layer_test( + keras.layers.LSTM, + kwargs={'units': units, + 'return_sequences': True}, + input_shape=(num_samples, timesteps, embedding_dim)) def test_static_shape_inference_LSTM(self): # Github issue: 15165 - num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 @@ -53,48 +54,47 @@ class LSTMLayerTest(test.TestCase): layer = keras.layers.LSTM(units, return_sequences=True) model.add(layer) outputs = model.layers[-1].output - self.assertEquals(outputs.get_shape().as_list(), - [None, timesteps, units]) + self.assertEquals(outputs.get_shape().as_list(), [None, timesteps, units]) + @tf_test_util.run_in_graph_and_eager_modes() def test_dynamic_behavior_LSTM(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - layer = keras.layers.LSTM(units, input_shape=(None, embedding_dim)) - model = keras.models.Sequential() - model.add(layer) - model.compile('sgd', 'mse') - x = np.random.random((num_samples, timesteps, embedding_dim)) - y = np.random.random((num_samples, units)) - model.train_on_batch(x, y) + layer = keras.layers.LSTM(units, input_shape=(None, embedding_dim)) + model = keras.models.Sequential() + model.add(layer) + model.compile(RMSPropOptimizer(0.001), 'mse') + x = np.random.random((num_samples, timesteps, embedding_dim)) + y = np.random.random((num_samples, units)) + model.train_on_batch(x, y) + @tf_test_util.run_in_graph_and_eager_modes() def test_dropout_LSTM(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - testing_utils.layer_test( - keras.layers.LSTM, - kwargs={'units': units, - 'dropout': 0.1, - 'recurrent_dropout': 0.1}, - input_shape=(num_samples, timesteps, embedding_dim)) - + testing_utils.layer_test( + keras.layers.LSTM, + kwargs={'units': units, + 'dropout': 0.1, + 'recurrent_dropout': 0.1}, + input_shape=(num_samples, timesteps, embedding_dim)) + + @tf_test_util.run_in_graph_and_eager_modes() def test_implementation_mode_LSTM(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - for mode in [0, 1, 2]: - testing_utils.layer_test( - keras.layers.LSTM, - kwargs={'units': units, - 'implementation': mode}, - input_shape=(num_samples, timesteps, embedding_dim)) + for mode in [0, 1, 2]: + testing_utils.layer_test( + keras.layers.LSTM, + kwargs={'units': units, + 'implementation': mode}, + input_shape=(num_samples, timesteps, embedding_dim)) def test_statefulness_LSTM(self): num_samples = 2 diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py index cdf2878e83e32147d30d6b29742b7e9013a1facb..c660cbd449b11a139f64cfa8b3a35310a597491c 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge.py +++ b/tensorflow/python/keras/_impl/keras/layers/merge.py @@ -21,8 +21,8 @@ from __future__ import division from __future__ import print_function from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras.engine.topology import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import Layer +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/merge_test.py b/tensorflow/python/keras/_impl/keras/layers/merge_test.py index bb03dda1fc645222c1ced97cfce8d459586dd89d..b2fe06f93e33ed63d6a2aa29522ecb552f582440 100644 --- a/tensorflow/python/keras/_impl/keras/layers/merge_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/merge_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -27,24 +28,25 @@ from tensorflow.python.platform import test class MergeLayersTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes() def test_merge_add(self): - with self.test_session(): - i1 = keras.layers.Input(shape=(4, 5)) - i2 = keras.layers.Input(shape=(4, 5)) - i3 = keras.layers.Input(shape=(4, 5)) + i1 = keras.layers.Input(shape=(4, 5)) + i2 = keras.layers.Input(shape=(4, 5)) + i3 = keras.layers.Input(shape=(4, 5)) - o = keras.layers.add([i1, i2, i3]) - self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) - model = keras.models.Model([i1, i2, i3], o) + o = keras.layers.add([i1, i2, i3]) + self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) + model = keras.models.Model([i1, i2, i3], o) - x1 = np.random.random((2, 4, 5)) - x2 = np.random.random((2, 4, 5)) - x3 = np.random.random((2, 4, 5)) - out = model.predict([x1, x2, x3]) - self.assertEqual(out.shape, (2, 4, 5)) - self.assertAllClose(out, x1 + x2 + x3, atol=1e-4) + x1 = np.random.random((2, 4, 5)) + x2 = np.random.random((2, 4, 5)) + x3 = np.random.random((2, 4, 5)) + out = model.predict([x1, x2, x3]) + self.assertEqual(out.shape, (2, 4, 5)) + self.assertAllClose(out, x1 + x2 + x3, atol=1e-4) - # test masking + def test_merge_add_masking(self): + with self.test_session(): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) m1 = keras.layers.Masking()(i1) @@ -54,11 +56,13 @@ class MergeLayersTest(test.TestCase): mask = layer.output_mask self.assertListEqual(mask.get_shape().as_list(), [None, 4]) - # test missing shape + def test_merge_add_dynamic_shape(self): + with self.test_session(): i1 = array_ops.placeholder(shape=(4, None), dtype='float32') i2 = array_ops.placeholder(shape=(4, 5), dtype='float32') layer = keras.layers.Add() o = layer([i1, i2]) + self.assertListEqual(o.get_shape().as_list(), [4, 5]) def test_merge_elementwise_errors(self): i1 = keras.layers.Input(shape=(4, 5)) @@ -72,79 +76,82 @@ class MergeLayersTest(test.TestCase): with self.assertRaises(ValueError): keras.layers.add([i1]) + @tf_test_util.run_in_graph_and_eager_modes() def test_merge_multiply(self): - with self.test_session(): - i1 = keras.layers.Input(shape=(4, 5)) - i2 = keras.layers.Input(shape=(4, 5)) - i3 = keras.layers.Input(shape=(4, 5)) - o = keras.layers.multiply([i1, i2, i3]) - self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) - model = keras.models.Model([i1, i2, i3], o) - - x1 = np.random.random((2, 4, 5)) - x2 = np.random.random((2, 4, 5)) - x3 = np.random.random((2, 4, 5)) - out = model.predict([x1, x2, x3]) - self.assertEqual(out.shape, (2, 4, 5)) - self.assertAllClose(out, x1 * x2 * x3, atol=1e-4) - + i1 = keras.layers.Input(shape=(4, 5)) + i2 = keras.layers.Input(shape=(4, 5)) + i3 = keras.layers.Input(shape=(4, 5)) + o = keras.layers.multiply([i1, i2, i3]) + self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) + model = keras.models.Model([i1, i2, i3], o) + + x1 = np.random.random((2, 4, 5)) + x2 = np.random.random((2, 4, 5)) + x3 = np.random.random((2, 4, 5)) + out = model.predict([x1, x2, x3]) + self.assertEqual(out.shape, (2, 4, 5)) + self.assertAllClose(out, x1 * x2 * x3, atol=1e-4) + + @tf_test_util.run_in_graph_and_eager_modes() def test_merge_average(self): - with self.test_session(): - i1 = keras.layers.Input(shape=(4, 5)) - i2 = keras.layers.Input(shape=(4, 5)) - o = keras.layers.average([i1, i2]) - self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) - model = keras.models.Model([i1, i2], o) + i1 = keras.layers.Input(shape=(4, 5)) + i2 = keras.layers.Input(shape=(4, 5)) + o = keras.layers.average([i1, i2]) + self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) + model = keras.models.Model([i1, i2], o) - x1 = np.random.random((2, 4, 5)) - x2 = np.random.random((2, 4, 5)) - out = model.predict([x1, x2]) - self.assertEqual(out.shape, (2, 4, 5)) - self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4) + x1 = np.random.random((2, 4, 5)) + x2 = np.random.random((2, 4, 5)) + out = model.predict([x1, x2]) + self.assertEqual(out.shape, (2, 4, 5)) + self.assertAllClose(out, 0.5 * (x1 + x2), atol=1e-4) + @tf_test_util.run_in_graph_and_eager_modes() def test_merge_maximum(self): - with self.test_session(): - i1 = keras.layers.Input(shape=(4, 5)) - i2 = keras.layers.Input(shape=(4, 5)) - o = keras.layers.maximum([i1, i2]) - self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) - model = keras.models.Model([i1, i2], o) + i1 = keras.layers.Input(shape=(4, 5)) + i2 = keras.layers.Input(shape=(4, 5)) + o = keras.layers.maximum([i1, i2]) + self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) + model = keras.models.Model([i1, i2], o) - x1 = np.random.random((2, 4, 5)) - x2 = np.random.random((2, 4, 5)) - out = model.predict([x1, x2]) - self.assertEqual(out.shape, (2, 4, 5)) - self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4) + x1 = np.random.random((2, 4, 5)) + x2 = np.random.random((2, 4, 5)) + out = model.predict([x1, x2]) + self.assertEqual(out.shape, (2, 4, 5)) + self.assertAllClose(out, np.maximum(x1, x2), atol=1e-4) + @tf_test_util.run_in_graph_and_eager_modes() def test_merge_minimum(self): - with self.test_session(): - i1 = keras.layers.Input(shape=(4, 5)) - i2 = keras.layers.Input(shape=(4, 5)) - o = keras.layers.minimum([i1, i2]) - self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) - model = keras.models.Model([i1, i2], o) + i1 = keras.layers.Input(shape=(4, 5)) + i2 = keras.layers.Input(shape=(4, 5)) + o = keras.layers.minimum([i1, i2]) + self.assertListEqual(o.get_shape().as_list(), [None, 4, 5]) + model = keras.models.Model([i1, i2], o) - x1 = np.random.random((2, 4, 5)) - x2 = np.random.random((2, 4, 5)) - out = model.predict([x1, x2]) - self.assertEqual(out.shape, (2, 4, 5)) - self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4) + x1 = np.random.random((2, 4, 5)) + x2 = np.random.random((2, 4, 5)) + out = model.predict([x1, x2]) + self.assertEqual(out.shape, (2, 4, 5)) + self.assertAllClose(out, np.minimum(x1, x2), atol=1e-4) + @tf_test_util.run_in_graph_and_eager_modes() def test_merge_concatenate(self): + i1 = keras.layers.Input(shape=(4, 5)) + i2 = keras.layers.Input(shape=(4, 5)) + o = keras.layers.concatenate([i1, i2], axis=1) + self.assertListEqual(o.get_shape().as_list(), [None, 8, 5]) + model = keras.models.Model([i1, i2], o) + + x1 = np.random.random((2, 4, 5)) + x2 = np.random.random((2, 4, 5)) + out = model.predict([x1, x2]) + self.assertEqual(out.shape, (2, 8, 5)) + self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4) + + def test_merge_concatenate_masking(self): with self.test_session(): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) - o = keras.layers.concatenate([i1, i2], axis=1) - self.assertListEqual(o.get_shape().as_list(), [None, 8, 5]) - model = keras.models.Model([i1, i2], o) - - x1 = np.random.random((2, 4, 5)) - x2 = np.random.random((2, 4, 5)) - out = model.predict([x1, x2]) - self.assertEqual(out.shape, (2, 8, 5)) - self.assertAllClose(out, np.concatenate([x1, x2], axis=1), atol=1e-4) - - # test masking m1 = keras.layers.Masking()(i1) layer = keras.layers.Concatenate() o = layer([m1, i2]) @@ -162,35 +169,35 @@ class MergeLayersTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'called on a list'): keras.layers.concatenate([i1], axis=-1) + @tf_test_util.run_in_graph_and_eager_modes() def test_merge_dot(self): - with self.test_session(): - i1 = keras.layers.Input(shape=(4,)) - i2 = keras.layers.Input(shape=(4,)) - o = keras.layers.dot([i1, i2], axes=1) - self.assertListEqual(o.get_shape().as_list(), [None, 1]) - model = keras.models.Model([i1, i2], o) - _ = keras.layers.Dot(axes=1).get_config() - - x1 = np.random.random((2, 4)) - x2 = np.random.random((2, 4)) - out = model.predict([x1, x2]) - self.assertEqual(out.shape, (2, 1)) - expected = np.zeros((2, 1)) - expected[0, 0] = np.dot(x1[0], x2[0]) - expected[1, 0] = np.dot(x1[1], x2[1]) - self.assertAllClose(out, expected, atol=1e-4) - - # Test with negative tuple of axes. - o = keras.layers.dot([i1, i2], axes=(-1, -1)) - self.assertListEqual(o.get_shape().as_list(), [None, 1]) - model = keras.models.Model([i1, i2], o) - out = model.predict([x1, x2]) - self.assertEqual(out.shape, (2, 1)) - self.assertAllClose(out, expected, atol=1e-4) - - # test compute_output_shape - layer = keras.layers.Dot(axes=-1) - self.assertEqual(layer.compute_output_shape([(4, 5), (4, 5)]), (4, 1)) + i1 = keras.layers.Input(shape=(4,)) + i2 = keras.layers.Input(shape=(4,)) + o = keras.layers.dot([i1, i2], axes=1) + self.assertListEqual(o.get_shape().as_list(), [None, 1]) + model = keras.models.Model([i1, i2], o) + _ = keras.layers.Dot(axes=1).get_config() + + x1 = np.random.random((2, 4)) + x2 = np.random.random((2, 4)) + out = model.predict([x1, x2]) + self.assertEqual(out.shape, (2, 1)) + expected = np.zeros((2, 1)) + expected[0, 0] = np.dot(x1[0], x2[0]) + expected[1, 0] = np.dot(x1[1], x2[1]) + self.assertAllClose(out, expected, atol=1e-4) + + # Test with negative tuple of axes. + o = keras.layers.dot([i1, i2], axes=(-1, -1)) + self.assertListEqual(o.get_shape().as_list(), [None, 1]) + model = keras.models.Model([i1, i2], o) + out = model.predict([x1, x2]) + self.assertEqual(out.shape, (2, 1)) + self.assertAllClose(out, expected, atol=1e-4) + + # test compute_output_shape + layer = keras.layers.Dot(axes=-1) + self.assertEqual(layer.compute_output_shape([(4, 5), (4, 5)]), (4, 1)) def test_dot_errors(self): i1 = keras.layers.Input(shape=(4, 5)) @@ -208,6 +215,7 @@ class MergeLayersTest(test.TestCase): dot = keras.layers.Dot(1) dot.compute_output_shape(1) + @tf_test_util.run_in_graph_and_eager_modes() def test_merge_subtract(self): i1 = keras.layers.Input(shape=(4, 5)) i2 = keras.layers.Input(shape=(4, 5)) diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py index 9010f4961585af58b7eae43dcd224e0c39606239..e309d160e5a9be97ff5f5356dad9dfaf85430233 100644 --- a/tensorflow/python/keras/_impl/keras/layers/noise.py +++ b/tensorflow/python/keras/_impl/keras/layers/noise.py @@ -22,7 +22,7 @@ import numpy as np from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.util.tf_export import tf_export diff --git a/tensorflow/python/keras/_impl/keras/layers/noise_test.py b/tensorflow/python/keras/_impl/keras/layers/noise_test.py index f9b4d9cd090ffec1a5acd9118ea6a65798bd72a6..af4f031ec95bb56b72c1f1018e0e529d8ff55564 100644 --- a/tensorflow/python/keras/_impl/keras/layers/noise_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/noise_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.platform import test @@ -39,12 +40,12 @@ class NoiseLayersTest(test.TestCase): kwargs={'rate': 0.5}, input_shape=(3, 2, 3)) + @tf_test_util.run_in_graph_and_eager_modes() def test_AlphaDropout(self): - with self.test_session(): - testing_utils.layer_test( - keras.layers.AlphaDropout, - kwargs={'rate': 0.2}, - input_shape=(3, 2, 3)) + testing_utils.layer_test( + keras.layers.AlphaDropout, + kwargs={'rate': 0.2}, + input_shape=(3, 2, 3)) if __name__ == '__main__': diff --git a/tensorflow/python/keras/_impl/keras/layers/pooling_test.py b/tensorflow/python/keras/_impl/keras/layers/pooling_test.py index ec0a5ae560f49ee39ecffb64f4ac65d3e800024c..70049f0976b7170005183bb4b028079b39a23afb 100644 --- a/tensorflow/python/keras/_impl/keras/layers/pooling_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/pooling_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.platform import test @@ -25,81 +27,85 @@ from tensorflow.python.platform import test class GlobalPoolingTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_globalpooling_1d(self): - with self.test_session(use_gpu=True): - testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D, - input_shape=(3, 4, 5)) - testing_utils.layer_test( - keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5)) + testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D, + input_shape=(3, 4, 5)) + testing_utils.layer_test( + keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5)) + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_globalpooling_2d(self): - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.pooling.GlobalMaxPooling2D, - kwargs={'data_format': 'channels_first'}, - input_shape=(3, 4, 5, 6)) - testing_utils.layer_test( - keras.layers.pooling.GlobalMaxPooling2D, - kwargs={'data_format': 'channels_last'}, - input_shape=(3, 5, 6, 4)) - testing_utils.layer_test( - keras.layers.pooling.GlobalAveragePooling2D, - kwargs={'data_format': 'channels_first'}, - input_shape=(3, 4, 5, 6)) - testing_utils.layer_test( - keras.layers.pooling.GlobalAveragePooling2D, - kwargs={'data_format': 'channels_last'}, - input_shape=(3, 5, 6, 4)) - + testing_utils.layer_test( + keras.layers.pooling.GlobalMaxPooling2D, + kwargs={'data_format': 'channels_first'}, + input_shape=(3, 4, 5, 6)) + testing_utils.layer_test( + keras.layers.pooling.GlobalMaxPooling2D, + kwargs={'data_format': 'channels_last'}, + input_shape=(3, 5, 6, 4)) + testing_utils.layer_test( + keras.layers.pooling.GlobalAveragePooling2D, + kwargs={'data_format': 'channels_first'}, + input_shape=(3, 4, 5, 6)) + testing_utils.layer_test( + keras.layers.pooling.GlobalAveragePooling2D, + kwargs={'data_format': 'channels_last'}, + input_shape=(3, 5, 6, 4)) + + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_globalpooling_3d(self): - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.pooling.GlobalMaxPooling3D, - kwargs={'data_format': 'channels_first'}, - input_shape=(3, 4, 3, 4, 3)) - testing_utils.layer_test( - keras.layers.pooling.GlobalMaxPooling3D, - kwargs={'data_format': 'channels_last'}, - input_shape=(3, 4, 3, 4, 3)) - testing_utils.layer_test( - keras.layers.pooling.GlobalAveragePooling3D, - kwargs={'data_format': 'channels_first'}, - input_shape=(3, 4, 3, 4, 3)) - testing_utils.layer_test( - keras.layers.pooling.GlobalAveragePooling3D, - kwargs={'data_format': 'channels_last'}, - input_shape=(3, 4, 3, 4, 3)) + testing_utils.layer_test( + keras.layers.pooling.GlobalMaxPooling3D, + kwargs={'data_format': 'channels_first'}, + input_shape=(3, 4, 3, 4, 3)) + testing_utils.layer_test( + keras.layers.pooling.GlobalMaxPooling3D, + kwargs={'data_format': 'channels_last'}, + input_shape=(3, 4, 3, 4, 3)) + testing_utils.layer_test( + keras.layers.pooling.GlobalAveragePooling3D, + kwargs={'data_format': 'channels_first'}, + input_shape=(3, 4, 3, 4, 3)) + testing_utils.layer_test( + keras.layers.pooling.GlobalAveragePooling3D, + kwargs={'data_format': 'channels_last'}, + input_shape=(3, 4, 3, 4, 3)) class Pooling2DTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_maxpooling_2d(self): pool_size = (3, 3) - with self.test_session(use_gpu=True): - for strides in [(1, 1), (2, 2)]: - testing_utils.layer_test( - keras.layers.MaxPooling2D, - kwargs={ - 'strides': strides, - 'padding': 'valid', - 'pool_size': pool_size - }, - input_shape=(3, 5, 6, 4)) - - def test_averagepooling_2d(self): - with self.test_session(use_gpu=True): + for strides in [(1, 1), (2, 2)]: testing_utils.layer_test( - keras.layers.AveragePooling2D, - kwargs={'strides': (2, 2), - 'padding': 'same', - 'pool_size': (2, 2)}, - input_shape=(3, 5, 6, 4)) - testing_utils.layer_test( - keras.layers.AveragePooling2D, - kwargs={'strides': (2, 2), - 'padding': 'valid', - 'pool_size': (3, 3)}, + keras.layers.MaxPooling2D, + kwargs={ + 'strides': strides, + 'padding': 'valid', + 'pool_size': pool_size + }, input_shape=(3, 5, 6, 4)) + + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) + def test_averagepooling_2d(self): + testing_utils.layer_test( + keras.layers.AveragePooling2D, + kwargs={'strides': (2, 2), + 'padding': 'same', + 'pool_size': (2, 2)}, + input_shape=(3, 5, 6, 4)) + testing_utils.layer_test( + keras.layers.AveragePooling2D, + kwargs={'strides': (2, 2), + 'padding': 'valid', + 'pool_size': (3, 3)}, + input_shape=(3, 5, 6, 4)) + + # This part of the test can only run on GPU but doesn't appear + # to be properly assigned to a GPU when running in eager mode. + if not context.in_eager_mode(): # Only runs on GPU with CUDA, channels_first is not supported on CPU. # TODO(b/62340061): Support channels_first on CPU. if test.is_gpu_available(cuda_only=True): @@ -116,66 +122,66 @@ class Pooling2DTest(test.TestCase): class Pooling3DTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_maxpooling_3d(self): pool_size = (3, 3, 3) - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.MaxPooling3D, - kwargs={'strides': 2, - 'padding': 'valid', - 'pool_size': pool_size}, - input_shape=(3, 11, 12, 10, 4)) - testing_utils.layer_test( - keras.layers.MaxPooling3D, - kwargs={ - 'strides': 3, - 'padding': 'valid', - 'data_format': 'channels_first', - 'pool_size': pool_size - }, - input_shape=(3, 4, 11, 12, 10)) - + testing_utils.layer_test( + keras.layers.MaxPooling3D, + kwargs={'strides': 2, + 'padding': 'valid', + 'pool_size': pool_size}, + input_shape=(3, 11, 12, 10, 4)) + testing_utils.layer_test( + keras.layers.MaxPooling3D, + kwargs={ + 'strides': 3, + 'padding': 'valid', + 'data_format': 'channels_first', + 'pool_size': pool_size + }, + input_shape=(3, 4, 11, 12, 10)) + + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_averagepooling_3d(self): pool_size = (3, 3, 3) - with self.test_session(use_gpu=True): - testing_utils.layer_test( - keras.layers.AveragePooling3D, - kwargs={'strides': 2, - 'padding': 'valid', - 'pool_size': pool_size}, - input_shape=(3, 11, 12, 10, 4)) - testing_utils.layer_test( - keras.layers.AveragePooling3D, - kwargs={ - 'strides': 3, - 'padding': 'valid', - 'data_format': 'channels_first', - 'pool_size': pool_size - }, - input_shape=(3, 4, 11, 12, 10)) + testing_utils.layer_test( + keras.layers.AveragePooling3D, + kwargs={'strides': 2, + 'padding': 'valid', + 'pool_size': pool_size}, + input_shape=(3, 11, 12, 10, 4)) + testing_utils.layer_test( + keras.layers.AveragePooling3D, + kwargs={ + 'strides': 3, + 'padding': 'valid', + 'data_format': 'channels_first', + 'pool_size': pool_size + }, + input_shape=(3, 4, 11, 12, 10)) class Pooling1DTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_maxpooling_1d(self): - with self.test_session(use_gpu=True): - for padding in ['valid', 'same']: - for stride in [1, 2]: - testing_utils.layer_test( - keras.layers.MaxPooling1D, - kwargs={'strides': stride, - 'padding': padding}, - input_shape=(3, 5, 4)) + for padding in ['valid', 'same']: + for stride in [1, 2]: + testing_utils.layer_test( + keras.layers.MaxPooling1D, + kwargs={'strides': stride, + 'padding': padding}, + input_shape=(3, 5, 4)) + @tf_test_util.run_in_graph_and_eager_modes(use_gpu=True) def test_averagepooling_1d(self): - with self.test_session(use_gpu=True): - for padding in ['valid', 'same']: - for stride in [1, 2]: - testing_utils.layer_test( - keras.layers.AveragePooling1D, - kwargs={'strides': stride, - 'padding': padding}, - input_shape=(3, 5, 4)) + for padding in ['valid', 'same']: + for stride in [1, 2]: + testing_utils.layer_test( + keras.layers.AveragePooling1D, + kwargs={'strides': stride, + 'padding': padding}, + input_shape=(3, 5, 4)) if __name__ == '__main__': diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py index 45f6711c77224875328ba346e6297fad3a681cb6..0264c7ae0119b36261a0a5467576c47a12a30801 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py @@ -22,6 +22,7 @@ from __future__ import print_function import numbers import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import activations from tensorflow.python.keras._impl.keras import backend as K @@ -30,7 +31,7 @@ from tensorflow.python.keras._impl.keras import initializers from tensorflow.python.keras._impl.keras import regularizers from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export @@ -88,7 +89,7 @@ class StackedRNNCells(Layer): state_size.append(cell.state_size) return tuple(state_size) - def call(self, inputs, states, **kwargs): + def call(self, inputs, states, constants=None, **kwargs): # Recover per-cell states. nested_states = [] for cell in self.cells[::-1]: @@ -103,7 +104,12 @@ class StackedRNNCells(Layer): # Call the cells in order and store the returned states. new_nested_states = [] for cell, states in zip(self.cells, nested_states): - inputs, states = cell.call(inputs, states, **kwargs) + if has_arg(cell.call, 'constants'): + inputs, states = cell.call(inputs, states, constants=constants, + **kwargs) + else: + inputs, states = cell.call(inputs, states, **kwargs) + new_nested_states.append(states) # Format the new states as a flat list @@ -115,9 +121,15 @@ class StackedRNNCells(Layer): @shape_type_conversion def build(self, input_shape): + if isinstance(input_shape, list): + constants_shape = input_shape[1:] + input_shape = input_shape[0] for cell in self.cells: if isinstance(cell, Layer): - cell.build(input_shape) + if has_arg(cell.call, 'constants'): + cell.build([input_shape] + constants_shape) + else: + cell.build(input_shape) if hasattr(cell.state_size, '__len__'): output_dim = cell.state_size[0] else: @@ -528,12 +540,14 @@ class RNN(Layer): self._num_constants = len(constants) additional_specs += self.constants_spec # at this point additional_inputs cannot be empty - is_keras_tensor = hasattr(additional_inputs[0], '_keras_history') + is_keras_tensor = K.is_keras_tensor(additional_inputs[0]) for tensor in additional_inputs: - if hasattr(tensor, '_keras_history') != is_keras_tensor: + if K.is_keras_tensor(tensor) != is_keras_tensor: raise ValueError('The initial state or constants of an RNN' ' layer cannot be specified with a mix of' - ' Keras tensors and non-Keras tensors') + ' Keras tensors and non-Keras tensors' + '(a "Keras tensor" is a tensor that was' + 'returned by a Keras layer, or by `Input`)') if is_keras_tensor: # Compute the full input spec, including state and constants @@ -797,7 +811,8 @@ class SimpleRNNCell(Layer): Arguments: units: Positive integer, dimensionality of the output space. activation: Activation function to use. - If you pass None, no activation is applied + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, @@ -921,7 +936,9 @@ class SimpleRNNCell(Layer): # Properly set learning phase on output tensor. if 0 < self.dropout + self.recurrent_dropout: - if training is None: + if training is None and not context.in_eager_mode(): + # This would be harmless to set in eager mode, but eager tensors + # disallow setting arbitrary attributes. output._uses_learning_phase = True return output, [output] @@ -967,6 +984,7 @@ class SimpleRNN(RNN): Arguments: units: Positive integer, dimensionality of the output space. activation: Activation function to use. + Default: hyperbolic tangent (`tanh`). If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. @@ -1177,10 +1195,14 @@ class GRUCell(Layer): Arguments: units: Positive integer, dimensionality of the output space. activation: Activation function to use. + Default: hyperbolic tangent (`tanh`). If you pass None, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step. + Default: hard sigmoid (`hard_sigmoid`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. @@ -1280,23 +1302,6 @@ class GRUCell(Layer): constraint=self.bias_constraint) else: self.bias = None - - self.kernel_z = self.kernel[:, :self.units] - self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units] - self.kernel_r = self.kernel[:, self.units:self.units * 2] - self.recurrent_kernel_r = self.recurrent_kernel[:, self.units: - self.units * 2] - self.kernel_h = self.kernel[:, self.units * 2:] - self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:] - - if self.use_bias: - self.bias_z = self.bias[:self.units] - self.bias_r = self.bias[self.units:self.units * 2] - self.bias_h = self.bias[self.units * 2:] - else: - self.bias_z = None - self.bias_r = None - self.bias_h = None self.built = True def call(self, inputs, states, training=None): @@ -1331,13 +1336,13 @@ class GRUCell(Layer): inputs_z = inputs inputs_r = inputs inputs_h = inputs - x_z = K.dot(inputs_z, self.kernel_z) - x_r = K.dot(inputs_r, self.kernel_r) - x_h = K.dot(inputs_h, self.kernel_h) + x_z = K.dot(inputs_z, self.kernel[:, :self.units]) + x_r = K.dot(inputs_r, self.kernel[:, self.units:self.units * 2]) + x_h = K.dot(inputs_h, self.kernel[:, self.units * 2:]) if self.use_bias: - x_z = K.bias_add(x_z, self.bias_z) - x_r = K.bias_add(x_r, self.bias_r) - x_h = K.bias_add(x_h, self.bias_h) + x_z = K.bias_add(x_z, self.bias[:self.units]) + x_r = K.bias_add(x_r, self.bias[self.units:self.units * 2]) + x_h = K.bias_add(x_h, self.bias[self.units * 2:]) if 0. < self.recurrent_dropout < 1.: h_tm1_z = h_tm1 * rec_dp_mask[0] @@ -1348,11 +1353,14 @@ class GRUCell(Layer): h_tm1_r = h_tm1 h_tm1_h = h_tm1 z = self.recurrent_activation( - x_z + K.dot(h_tm1_z, self.recurrent_kernel_z)) + x_z + K.dot(h_tm1_z, self.recurrent_kernel[:, :self.units])) r = self.recurrent_activation( - x_r + K.dot(h_tm1_r, self.recurrent_kernel_r)) + x_r + K.dot(h_tm1_r, self.recurrent_kernel[:, self.units: + self.units * 2])) - hh = self.activation(x_h + K.dot(r * h_tm1_h, self.recurrent_kernel_h)) + hh = self.activation(x_h + K.dot(r * h_tm1_h, + self.recurrent_kernel[:, + self.units * 2:])) else: if 0. < self.dropout < 1.: inputs *= dp_mask[0] @@ -1376,44 +1384,34 @@ class GRUCell(Layer): hh = self.activation(x_h + recurrent_h) h = z * h_tm1 + (1 - z) * hh if 0 < self.dropout + self.recurrent_dropout: - if training is None: + if training is None and not context.in_eager_mode(): + # This would be harmless to set in eager mode, but eager tensors + # disallow setting arbitrary attributes. h._uses_learning_phase = True return h, [h] def get_config(self): config = { - 'units': - self.units, - 'activation': - activations.serialize(self.activation), + 'units': self.units, + 'activation': activations.serialize(self.activation), 'recurrent_activation': activations.serialize(self.recurrent_activation), - 'use_bias': - self.use_bias, - 'kernel_initializer': - initializers.serialize(self.kernel_initializer), + 'use_bias': self.use_bias, + 'kernel_initializer': initializers.serialize(self.kernel_initializer), 'recurrent_initializer': initializers.serialize(self.recurrent_initializer), - 'bias_initializer': - initializers.serialize(self.bias_initializer), - 'kernel_regularizer': - regularizers.serialize(self.kernel_regularizer), + 'bias_initializer': initializers.serialize(self.bias_initializer), + 'kernel_regularizer': regularizers.serialize(self.kernel_regularizer), 'recurrent_regularizer': regularizers.serialize(self.recurrent_regularizer), - 'bias_regularizer': - regularizers.serialize(self.bias_regularizer), - 'kernel_constraint': - constraints.serialize(self.kernel_constraint), + 'bias_regularizer': regularizers.serialize(self.bias_regularizer), + 'kernel_constraint': constraints.serialize(self.kernel_constraint), 'recurrent_constraint': constraints.serialize(self.recurrent_constraint), - 'bias_constraint': - constraints.serialize(self.bias_constraint), - 'dropout': - self.dropout, - 'recurrent_dropout': - self.recurrent_dropout, - 'implementation': - self.implementation + 'bias_constraint': constraints.serialize(self.bias_constraint), + 'dropout': self.dropout, + 'recurrent_dropout': self.recurrent_dropout, + 'implementation': self.implementation } base_config = super(GRUCell, self).get_config() return dict(list(base_config.items()) + list(config.items())) @@ -1428,10 +1426,14 @@ class GRU(RNN): Arguments: units: Positive integer, dimensionality of the output space. activation: Activation function to use. - If you pass None, no activation is applied + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step. + Default: hard sigmoid (`hard_sigmoid`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. @@ -1662,10 +1664,14 @@ class LSTMCell(Layer): Arguments: units: Positive integer, dimensionality of the output space. activation: Activation function to use. - If you pass None, no activation is applied + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step. + Default: hard sigmoid (`hard_sigmoid`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`).x use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs. @@ -1782,29 +1788,6 @@ class LSTMCell(Layer): constraint=self.bias_constraint) else: self.bias = None - - self.kernel_i = self.kernel[:, :self.units] - self.kernel_f = self.kernel[:, self.units:self.units * 2] - self.kernel_c = self.kernel[:, self.units * 2:self.units * 3] - self.kernel_o = self.kernel[:, self.units * 3:] - - self.recurrent_kernel_i = self.recurrent_kernel[:, :self.units] - self.recurrent_kernel_f = self.recurrent_kernel[:, self.units: - self.units * 2] - self.recurrent_kernel_c = self.recurrent_kernel[:, self.units * 2: - self.units * 3] - self.recurrent_kernel_o = self.recurrent_kernel[:, self.units * 3:] - - if self.use_bias: - self.bias_i = self.bias[:self.units] - self.bias_f = self.bias[self.units:self.units * 2] - self.bias_c = self.bias[self.units * 2:self.units * 3] - self.bias_o = self.bias[self.units * 3:] - else: - self.bias_i = None - self.bias_f = None - self.bias_c = None - self.bias_o = None self.built = True def call(self, inputs, states, training=None): @@ -1842,15 +1825,15 @@ class LSTMCell(Layer): inputs_f = inputs inputs_c = inputs inputs_o = inputs - x_i = K.dot(inputs_i, self.kernel_i) - x_f = K.dot(inputs_f, self.kernel_f) - x_c = K.dot(inputs_c, self.kernel_c) - x_o = K.dot(inputs_o, self.kernel_o) + x_i = K.dot(inputs_i, self.kernel[:, :self.units]) + x_f = K.dot(inputs_f, self.kernel[:, self.units:self.units * 2]) + x_c = K.dot(inputs_c, self.kernel[:, self.units * 2:self.units * 3]) + x_o = K.dot(inputs_o, self.kernel[:, self.units * 3:]) if self.use_bias: - x_i = K.bias_add(x_i, self.bias_i) - x_f = K.bias_add(x_f, self.bias_f) - x_c = K.bias_add(x_c, self.bias_c) - x_o = K.bias_add(x_o, self.bias_o) + x_i = K.bias_add(x_i, self.bias[:self.units]) + x_f = K.bias_add(x_f, self.bias[self.units:self.units * 2]) + x_c = K.bias_add(x_c, self.bias[self.units * 2:self.units * 3]) + x_o = K.bias_add(x_o, self.bias[self.units * 3:]) if 0 < self.recurrent_dropout < 1.: h_tm1_i = h_tm1 * rec_dp_mask[0] @@ -1863,13 +1846,15 @@ class LSTMCell(Layer): h_tm1_c = h_tm1 h_tm1_o = h_tm1 i = self.recurrent_activation( - x_i + K.dot(h_tm1_i, self.recurrent_kernel_i)) + x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])) f = self.recurrent_activation( - x_f + K.dot(h_tm1_f, self.recurrent_kernel_f)) + x_f + K.dot(h_tm1_f, + self.recurrent_kernel[:, self.units: self.units * 2])) c = f * c_tm1 + i * self.activation( - x_c + K.dot(h_tm1_c, self.recurrent_kernel_c)) + x_c + K.dot(h_tm1_c, + self.recurrent_kernel[:, self.units * 2: self.units * 3])) o = self.recurrent_activation( - x_o + K.dot(h_tm1_o, self.recurrent_kernel_o)) + x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])) else: if 0. < self.dropout < 1.: inputs *= dp_mask[0] @@ -1892,7 +1877,9 @@ class LSTMCell(Layer): h = o * self.activation(c) if 0 < self.dropout + self.recurrent_dropout: - if training is None: + if training is None and not context.in_eager_mode(): + # This would be harmless to set in eager mode, but eager tensors + # disallow setting arbitrary attributes. h._uses_learning_phase = True return h, [h, c] @@ -1944,10 +1931,14 @@ class LSTM(RNN): Arguments: units: Positive integer, dimensionality of the output space. activation: Activation function to use. - If you pass None, no activation is applied + Default: hyperbolic tangent (`tanh`). + If you pass `None`, no activation is applied (ie. "linear" activation: `a(x) = x`). recurrent_activation: Activation function to use for the recurrent step. + Default: hard sigmoid (`hard_sigmoid`). + If you pass `None`, no activation is applied + (ie. "linear" activation: `a(x) = x`). use_bias: Boolean, whether the layer uses a bias vector. kernel_initializer: Initializer for the `kernel` weights matrix, used for the linear transformation of the inputs.. diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py index ab48a63e3544534567ee3205bb74174cda6e1769..de022153f6f07240a0dff70e5faeed5b6d4a8c5f 100644 --- a/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/recurrent_test.py @@ -253,7 +253,7 @@ class RNNTest(test.TestCase): self.assertAllClose(y_np, y_np_2, atol=1e-4) with self.test_session(): - # test flat list inputs + # test flat list inputs. with keras.utils.CustomObjectScope(custom_objects): layer = keras.layers.RNN.from_config(config.copy()) y = layer([x, c]) @@ -262,6 +262,35 @@ class RNNTest(test.TestCase): y_np_3 = model.predict([x_np, c_np]) self.assertAllClose(y_np, y_np_3, atol=1e-4) + with self.test_session(): + # Test stacking. + cells = [keras.layers.recurrent.GRUCell(8), + RNNCellWithConstants(12), + RNNCellWithConstants(32)] + layer = keras.layers.recurrent.RNN(cells) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.compile(optimizer='rmsprop', loss='mse') + model.train_on_batch( + [np.zeros((6, 5, 5)), np.zeros((6, 3))], + np.zeros((6, 32)) + ) + + with self.test_session(): + # Test stacked RNN serialization + x_np = np.random.random((6, 5, 5)) + c_np = np.random.random((6, 3)) + y_np = model.predict([x_np, c_np]) + weights = model.get_weights() + config = layer.get_config() + with keras.utils.CustomObjectScope(custom_objects): + layer = keras.layers.recurrent.RNN.from_config(config.copy()) + y = layer(x, constants=c) + model = keras.models.Model([x, c], y) + model.set_weights(weights) + y_np_2 = model.predict([x_np, c_np]) + self.assertAllClose(y_np, y_np_2, atol=1e-4) + def test_rnn_cell_with_constants_layer_passing_initial_state(self): class RNNCellWithConstants(keras.layers.Layer): diff --git a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py index 7edebdacd07d74fe6b5a982d12645fb5556bdf75..8c7189cd4718450a85c015e08ab3a58cc5d86531 100644 --- a/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/simplernn_test.py @@ -20,64 +20,66 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.keras._impl.keras import testing_utils from tensorflow.python.platform import test +from tensorflow.python.training.rmsprop import RMSPropOptimizer class SimpleRNNLayerTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes() def test_return_sequences_SimpleRNN(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - testing_utils.layer_test( - keras.layers.SimpleRNN, - kwargs={'units': units, - 'return_sequences': True}, - input_shape=(num_samples, timesteps, embedding_dim)) + testing_utils.layer_test( + keras.layers.SimpleRNN, + kwargs={'units': units, + 'return_sequences': True}, + input_shape=(num_samples, timesteps, embedding_dim)) + @tf_test_util.run_in_graph_and_eager_modes() def test_dynamic_behavior_SimpleRNN(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - layer = keras.layers.SimpleRNN(units, input_shape=(None, embedding_dim)) - model = keras.models.Sequential() - model.add(layer) - model.compile('sgd', 'mse') - x = np.random.random((num_samples, timesteps, embedding_dim)) - y = np.random.random((num_samples, units)) - model.train_on_batch(x, y) - + layer = keras.layers.SimpleRNN(units, input_shape=(None, embedding_dim)) + model = keras.models.Sequential() + model.add(layer) + model.compile(RMSPropOptimizer(0.01), 'mse') + x = np.random.random((num_samples, timesteps, embedding_dim)) + y = np.random.random((num_samples, units)) + model.train_on_batch(x, y) + + @tf_test_util.run_in_graph_and_eager_modes() def test_dropout_SimpleRNN(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - testing_utils.layer_test( - keras.layers.SimpleRNN, - kwargs={'units': units, - 'dropout': 0.1, - 'recurrent_dropout': 0.1}, - input_shape=(num_samples, timesteps, embedding_dim)) - + testing_utils.layer_test( + keras.layers.SimpleRNN, + kwargs={'units': units, + 'dropout': 0.1, + 'recurrent_dropout': 0.1}, + input_shape=(num_samples, timesteps, embedding_dim)) + + @tf_test_util.run_in_graph_and_eager_modes() def test_implementation_mode_SimpleRNN(self): num_samples = 2 timesteps = 3 embedding_dim = 4 units = 2 - with self.test_session(): - for mode in [0, 1, 2]: - testing_utils.layer_test( - keras.layers.SimpleRNN, - kwargs={'units': units, - 'implementation': mode}, - input_shape=(num_samples, timesteps, embedding_dim)) + for mode in [0, 1, 2]: + testing_utils.layer_test( + keras.layers.SimpleRNN, + kwargs={'units': units, + 'implementation': mode}, + input_shape=(num_samples, timesteps, embedding_dim)) def test_statefulness_SimpleRNN(self): num_samples = 2 diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py index f053aa1d09570e76aa0b6b9733c0b0bb438e24a0..76ddd9299dd669da35d89a6fe8fc521ce4c26337 100644 --- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py +++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py @@ -25,7 +25,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl.keras import backend as K from tensorflow.python.keras._impl.keras.engine import InputSpec from tensorflow.python.keras._impl.keras.engine import Layer -from tensorflow.python.keras._impl.keras.engine.topology import shape_type_conversion +from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg from tensorflow.python.layers import utils as tf_layers_util from tensorflow.python.util.tf_export import tf_export @@ -61,6 +61,14 @@ class Wrapper(Layer): else: return None + @property + def trainable(self): + return self.layer.trainable + + @trainable.setter + def trainable(self, value): + self.layer.trainable = value + @property def trainable_weights(self): return self.layer.trainable_weights @@ -255,7 +263,6 @@ class Bidirectional(Wrapper): """ def __init__(self, layer, merge_mode='concat', weights=None, **kwargs): - super(Bidirectional, self).__init__(layer, **kwargs) if merge_mode not in ['sum', 'mul', 'ave', 'concat', None]: raise ValueError('Invalid merge mode. ' 'Merge mode should be one of ' @@ -275,6 +282,19 @@ class Bidirectional(Wrapper): self.return_sequences = layer.return_sequences self.return_state = layer.return_state self.supports_masking = True + self._trainable = True + super(Bidirectional, self).__init__(layer, **kwargs) + self.input_spec = layer.input_spec + + @property + def trainable(self): + return self._trainable + + @trainable.setter + def trainable(self, value): + self._trainable = value + self.forward_layer.trainable = value + self.backward_layer.trainable = value def get_weights(self): return self.forward_layer.get_weights() + self.backward_layer.get_weights() @@ -305,6 +325,61 @@ class Bidirectional(Wrapper): return [output_shape] + state_shape + copy.copy(state_shape) return output_shape + def __call__(self, inputs, initial_state=None, **kwargs): + if isinstance(inputs, list): + if len(inputs) > 1: + initial_state = inputs[1:] + inputs = inputs[0] + + if initial_state is None: + return super(Bidirectional, self).__call__(inputs, **kwargs) + + # Standardize `initial_state` into list + if isinstance(initial_state, tuple): + initial_state = list(initial_state) + elif not isinstance(initial_state, list): + initial_state = [initial_state] + + # Check if `initial_state` can be splitted into half + num_states = len(initial_state) + if num_states % 2 > 0: + raise ValueError( + 'When passing `initial_state` to a Bidirectional RNN, the state ' + 'should be a list containing the states of the underlying RNNs. ' + 'Found: ' + str(initial_state)) + + # Applies the same workaround as in `RNN.__call__`, without handling + # constants + kwargs['initial_state'] = initial_state + additional_inputs = initial_state + additional_specs = [InputSpec(shape=K.int_shape(state)) + for state in initial_state] + self.forward_layer.state_spec = additional_specs[:num_states // 2] + self.backward_layer.state_spec = additional_specs[num_states // 2:] + + is_keras_tensor = K.is_keras_tensor(additional_inputs[0]) + for tensor in additional_inputs: + if K.is_keras_tensor(tensor) != is_keras_tensor: + raise ValueError('The initial state of a Bidirectional' + ' layer cannot be specified with a mix of' + ' Keras tensors and non-Keras tensors' + ' (a "Keras tensor" is a tensor that was' + ' returned by a Keras layer, or by `Input`)') + + if is_keras_tensor: + # Compute the full input spec, including state + full_input = [inputs] + additional_inputs + full_input_spec = self.input_spec + additional_specs + + # Perform the call with temporarily replaced input_spec + original_input_spec = self.input_spec + self.input_spec = full_input_spec + output = super(Bidirectional, self).__call__(full_input, **kwargs) + self.input_spec = original_input_spec + return output + else: + return super(Bidirectional, self).__call__(inputs, **kwargs) + def call(self, inputs, training=None, mask=None, initial_state=None): kwargs = {} if has_arg(self.layer.call, 'training'): @@ -313,11 +388,6 @@ class Bidirectional(Wrapper): kwargs['mask'] = mask if initial_state is not None and has_arg(self.layer.call, 'initial_state'): - if not isinstance(initial_state, list): - raise ValueError( - 'When passing `initial_state` to a Bidirectional RNN, the state ' - 'should be a list containing the states of the underlying RNNs. ' - 'Found: ' + str(initial_state)) forward_state = initial_state[:len(initial_state) // 2] backward_state = initial_state[len(initial_state) // 2:] y = self.forward_layer.call(inputs, initial_state=forward_state, **kwargs) diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py index f48c8919a148403874758b618aaa9a662e511240..8fcf66e90ff1289a06a996768ae5de2f1548a27c 100644 --- a/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/_impl/keras/layers/wrappers_test.py @@ -20,44 +20,43 @@ from __future__ import print_function import numpy as np +from tensorflow.python.framework import test_util as tf_test_util from tensorflow.python.keras._impl import keras from tensorflow.python.platform import test +from tensorflow.python.training.rmsprop import RMSPropOptimizer class TimeDistributedTest(test.TestCase): + @tf_test_util.run_in_graph_and_eager_modes() def test_timedistributed_dense(self): - # first, test with Dense layer - with self.test_session(): - model = keras.models.Sequential() - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(2), input_shape=(3, 4))) - model.compile(optimizer='rmsprop', loss='mse') - model.fit( - np.random.random((10, 3, 4)), - np.random.random((10, 3, 2)), - epochs=1, - batch_size=10) - - # test config - model.get_config() + model = keras.models.Sequential() + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(2), input_shape=(3, 4))) + model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse') + model.fit( + np.random.random((10, 3, 4)), + np.random.random((10, 3, 2)), + epochs=1, + batch_size=10) + + # test config + model.get_config() def test_timedistributed_static_batch_size(self): - with self.test_session(): - model = keras.models.Sequential() - model.add( - keras.layers.TimeDistributed( - keras.layers.Dense(2), input_shape=(3, 4), batch_size=10)) - model.compile(optimizer='rmsprop', loss='mse') - model.fit( - np.random.random((10, 3, 4)), - np.random.random((10, 3, 2)), - epochs=1, - batch_size=10) + model = keras.models.Sequential() + model.add( + keras.layers.TimeDistributed( + keras.layers.Dense(2), input_shape=(3, 4), batch_size=10)) + model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse') + model.fit( + np.random.random((10, 3, 4)), + np.random.random((10, 3, 2)), + epochs=1, + batch_size=10) def test_timedistributed_conv2d(self): - # test with Conv2D with self.test_session(): model = keras.models.Sequential() model.add( @@ -73,7 +72,6 @@ class TimeDistributedTest(test.TestCase): model.summary() def test_timedistributed_stacked(self): - # test stacked layers with self.test_session(): model = keras.models.Sequential() model.add( @@ -133,6 +131,20 @@ class TimeDistributedTest(test.TestCase): # Verify input_map has one mapping from inputs to reshaped inputs. self.assertEqual(len(td._input_map.keys()), 1) + def test_TimeDistributed_trainable(self): + # test layers that need learning_phase to be set + x = keras.layers.Input(shape=(3, 2)) + layer = keras.layers.TimeDistributed(keras.layers.BatchNormalization()) + _ = layer(x) + assert len(layer.updates) == 2 + assert len(layer.trainable_weights) == 2 + layer.trainable = False + assert not layer.updates + assert not layer.trainable_weights + layer.trainable = True + assert len(layer.updates) == 2 + assert len(layer.trainable_weights) == 2 + class BidirectionalTest(test.TestCase): @@ -153,7 +165,7 @@ class BidirectionalTest(test.TestCase): model.add( keras.layers.Bidirectional( rnn(output_dim), merge_mode=mode, input_shape=(timesteps, dim))) - model.compile(loss='mse', optimizer='sgd') + model.compile(optimizer=RMSPropOptimizer(0.01), loss='mse') model.fit(x, y, epochs=1, batch_size=1) # test compute output shape @@ -338,23 +350,38 @@ class BidirectionalTest(test.TestCase): units = 3 with self.test_session(): - inputs = keras.Input((timesteps, dim)) + input1 = keras.layers.Input((timesteps, dim)) layer = keras.layers.Bidirectional( rnn(units, return_state=True, return_sequences=True)) - outputs = layer(inputs) - output, state = outputs[0], outputs[1:] + state = layer(input1)[1:] # test passing invalid initial_state: passing a tensor + input2 = keras.layers.Input((timesteps, dim)) with self.assertRaises(ValueError): output = keras.layers.Bidirectional( - rnn(units))(output, initial_state=state[0]) + rnn(units))(input2, initial_state=state[0]) # test valid usage: passing a list - output = keras.layers.Bidirectional( - rnn(units))(output, initial_state=state) - model = keras.Model(inputs, output) - inputs = np.random.rand(samples, timesteps, dim) - outputs = model.predict(inputs) + output = keras.layers.Bidirectional(rnn(units))(input2, + initial_state=state) + model = keras.models.Model([input1, input2], output) + assert len(model.layers) == 4 + assert isinstance(model.layers[-1].input, list) + inputs = [np.random.rand(samples, timesteps, dim), + np.random.rand(samples, timesteps, dim)] + model.predict(inputs) + + def test_Bidirectional_trainable(self): + # test layers that need learning_phase to be set + with self.test_session(): + x = keras.layers.Input(shape=(3, 2)) + layer = keras.layers.Bidirectional(keras.layers.SimpleRNN(3)) + _ = layer(x) + assert len(layer.trainable_weights) == 6 + layer.trainable = False + assert not layer.trainable_weights + layer.trainable = True + assert len(layer.trainable_weights) == 6 def _to_list(ls): diff --git a/tensorflow/python/keras/_impl/keras/metrics.py b/tensorflow/python/keras/_impl/keras/metrics.py index 0e2fb6365a2d9fda987d1326d8a48f40b55672f4..82778a3dc4fbdc13bb6682d01e28ff68882b6dd9 100644 --- a/tensorflow/python/keras/_impl/keras/metrics.py +++ b/tensorflow/python/keras/_impl/keras/metrics.py @@ -36,6 +36,7 @@ from tensorflow.python.keras._impl.keras.losses import poisson from tensorflow.python.keras._impl.keras.losses import sparse_categorical_crossentropy from tensorflow.python.keras._impl.keras.losses import squared_hinge from tensorflow.python.keras._impl.keras.utils.generic_utils import deserialize_keras_object +from tensorflow.python.keras._impl.keras.utils.generic_utils import serialize_keras_object from tensorflow.python.util.tf_export import tf_export @@ -79,13 +80,13 @@ cosine = cosine_proximity @tf_export('keras.metrics.serialize') def serialize(metric): - return metric.__name__ + return serialize_keras_object(metric) @tf_export('keras.metrics.deserialize') -def deserialize(name, custom_objects=None): +def deserialize(config, custom_objects=None): return deserialize_keras_object( - name, + config, module_objects=globals(), custom_objects=custom_objects, printable_module_name='metric function') @@ -93,11 +94,13 @@ def deserialize(name, custom_objects=None): @tf_export('keras.metrics.get') def get(identifier): - if isinstance(identifier, six.string_types): - identifier = str(identifier) - return deserialize(identifier) + if isinstance(identifier, dict): + config = {'class_name': str(identifier), 'config': {}} + return deserialize(config) + elif isinstance(identifier, six.string_types): + return deserialize(str(identifier)) elif callable(identifier): return identifier else: raise ValueError('Could not interpret ' - 'metric function identifier:', identifier) + 'metric function identifier: %s' % identifier) diff --git a/tensorflow/python/keras/_impl/keras/metrics_test.py b/tensorflow/python/keras/_impl/keras/metrics_test.py index f4792f3543cc5ca8e5e7ad03d9906bbfadd1fb04..44289ea02abf5ae5f8befbe515552aea3d4b231e 100644 --- a/tensorflow/python/keras/_impl/keras/metrics_test.py +++ b/tensorflow/python/keras/_impl/keras/metrics_test.py @@ -72,6 +72,77 @@ class KerasMetricsTest(test.TestCase): keras.metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) self.assertEqual(result, 0.) + def test_stateful_metrics(self): + np.random.seed(1334) + + class BinaryTruePositives(keras.layers.Layer): + """Stateful Metric to count the total true positives over all batches. + + Assumes predictions and targets of shape `(samples, 1)`. + + Arguments: + threshold: Float, lower limit on prediction value that counts as a + positive class prediction. + name: String, name for the metric. + """ + + def __init__(self, name='true_positives', **kwargs): + super(BinaryTruePositives, self).__init__(name=name, **kwargs) + self.true_positives = keras.backend.variable(value=0, dtype='int32') + + def reset_states(self): + keras.backend.set_value(self.true_positives, 0) + + def __call__(self, y_true, y_pred): + """Computes the number of true positives in a batch. + + Args: + y_true: Tensor, batch_wise labels + y_pred: Tensor, batch_wise predictions + + Returns: + The total number of true positives seen this epoch at the + completion of the batch. + """ + y_true = keras.backend.cast(y_true, 'int32') + y_pred = keras.backend.cast(keras.backend.round(y_pred), 'int32') + correct_preds = keras.backend.cast( + keras.backend.equal(y_pred, y_true), 'int32') + true_pos = keras.backend.cast( + keras.backend.sum(correct_preds * y_true), 'int32') + current_true_pos = self.true_positives * 1 + self.add_update(keras.backend.update_add(self.true_positives, + true_pos), + inputs=[y_true, y_pred]) + return current_true_pos + true_pos + + metric_fn = BinaryTruePositives() + config = keras.metrics.serialize(metric_fn) + metric_fn = keras.metrics.deserialize( + config, custom_objects={'BinaryTruePositives': BinaryTruePositives}) + + # Test on simple model + inputs = keras.Input(shape=(2,)) + outputs = keras.layers.Dense(1, activation='sigmoid')(inputs) + model = keras.Model(inputs, outputs) + model.compile(optimizer='sgd', + loss='binary_crossentropy', + metrics=['acc', metric_fn]) + + # Test fit, evaluate + samples = 1000 + x = np.random.random((samples, 2)) + y = np.random.randint(2, size=(samples, 1)) + model.fit(x, y, epochs=1, batch_size=10) + outs = model.evaluate(x, y, batch_size=10) + preds = model.predict(x) + + def ref_true_pos(y_true, y_pred): + return np.sum(np.logical_and(y_pred > 0.5, y_true == 1)) + + # Test correctness (e.g. updates should have been run) + self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3d71a620fcb34d21c41f920eed99b1fe22668899 --- /dev/null +++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py @@ -0,0 +1,589 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Model subclassing.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile + +import numpy as np + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.keras._impl import keras +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test +from tensorflow.python.training.rmsprop import RMSPropOptimizer + +try: + import h5py # pylint:disable=g-import-not-at-top +except ImportError: + h5py = None + + +class SimpleTestModel(keras.Model): + + def __init__(self, use_bn=False, use_dp=False, num_classes=10): + super(SimpleTestModel, self).__init__(name='test_model') + self.use_bn = use_bn + self.use_dp = use_dp + self.num_classes = num_classes + + self.dense1 = keras.layers.Dense(32, activation='relu') + self.dense2 = keras.layers.Dense(num_classes, activation='softmax') + if self.use_dp: + self.dp = keras.layers.Dropout(0.5) + if self.use_bn: + self.bn = keras.layers.BatchNormalization(axis=-1) + + def call(self, inputs): + x = self.dense1(inputs) + if self.use_dp: + x = self.dp(x) + if self.use_bn: + x = self.bn(x) + return self.dense2(x) + + +class MultiIOTestModel(keras.Model): + + def __init__(self, use_bn=False, use_dp=False, num_classes=(2, 3)): + super(MultiIOTestModel, self).__init__(name='test_model') + self.use_bn = use_bn + self.use_dp = use_dp + self.num_classes = num_classes + + self.dense1 = keras.layers.Dense(32, activation='relu') + self.dense2 = keras.layers.Dense(num_classes[0], activation='softmax') + self.dense3 = keras.layers.Dense(num_classes[1], activation='softmax') + if use_dp: + self.dp = keras.layers.Dropout(0.5) + if use_bn: + self.bn = keras.layers.BatchNormalization() + + def call(self, inputs): + x1, x2 = inputs + x1 = self.dense1(x1) + x2 = self.dense1(x2) + if self.use_dp: + x1 = self.dp(x1) + if self.use_bn: + x2 = self.bn(x2) + return [self.dense2(x1), self.dense3(x2)] + + +class NestedTestModel1(keras.Model): + """A model subclass nested inside a model subclass. + """ + + def __init__(self, num_classes=2): + super(NestedTestModel1, self).__init__(name='nested_model_1') + self.num_classes = num_classes + self.dense1 = keras.layers.Dense(32, activation='relu') + self.dense2 = keras.layers.Dense(num_classes, activation='relu') + self.bn = keras.layers.BatchNormalization() + self.test_net = SimpleTestModel(num_classes=4, + use_bn=True, + use_dp=True) + + def call(self, inputs): + x = self.dense1(inputs) + x = self.bn(x) + x = self.test_net(x) # pylint: disable=not-callable + return self.dense2(x) + + +def get_functional_graph_model(input_dim, num_classes): + # A simple functional-API model (a.k.a. graph network) + inputs = keras.Input(shape=(input_dim,)) + x = keras.layers.Dense(32, activation='relu')(inputs) + x = keras.layers.BatchNormalization()(x) + outputs = keras.layers.Dense(num_classes)(x) + return keras.Model(inputs, outputs) + + +class NestedTestModel2(keras.Model): + """A model subclass with a functional-API graph network inside. + """ + + def __init__(self, num_classes=2): + super(NestedTestModel2, self).__init__(name='nested_model_2') + self.num_classes = num_classes + self.dense1 = keras.layers.Dense(32, activation='relu') + self.dense2 = keras.layers.Dense(num_classes, activation='relu') + self.bn = self.bn = keras.layers.BatchNormalization() + self.test_net = get_functional_graph_model(32, 4) + + def call(self, inputs): + x = self.dense1(inputs) + x = self.bn(x) + x = self.test_net(x) + return self.dense2(x) + + +def get_nested_model_3(input_dim, num_classes): + # A functional-API model with a subclassed model inside. + # NOTE: this requires the inner subclass to implement `compute_output_shape`. + + inputs = keras.Input(shape=(input_dim,)) + x = keras.layers.Dense(32, activation='relu')(inputs) + x = keras.layers.BatchNormalization()(x) + + class Inner(keras.Model): + + def __init__(self): + super(Inner, self).__init__() + self.dense1 = keras.layers.Dense(32, activation='relu') + self.dense2 = keras.layers.Dense(5, activation='relu') + self.bn = keras.layers.BatchNormalization() + + def call(self, inputs): + x = self.dense1(inputs) + x = self.dense2(x) + return self.bn(x) + + def compute_output_shape(self, input_shape): + return tensor_shape.TensorShape((input_shape[0], 5)) + + test_model = Inner() + x = test_model(x) # pylint: disable=not-callable + outputs = keras.layers.Dense(num_classes)(x) + return keras.Model(inputs, outputs, name='nested_model_3') + + +class ModelSubclassingTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def test_single_io_workflow_with_np_arrays(self): + num_classes = 2 + num_samples = 100 + input_dim = 50 + + with self.test_session(): + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) + + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) + + @test_util.run_in_graph_and_eager_modes() + def test_multi_io_workflow_with_np_arrays(self): + num_classes = (2, 3) + num_samples = 1000 + input_dim = 50 + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + _ = model.evaluate([x1, x2], [y1, y2], verbose=0) + + def test_single_io_workflow_with_tensors(self): + + num_classes = 2 + num_samples = 10 + input_dim = 50 + + with self.test_session(): + model = SimpleTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + + x = array_ops.ones((num_samples, input_dim)) + y = array_ops.zeros((num_samples, num_classes)) + + model.fit(x, y, epochs=2, steps_per_epoch=10, verbose=0) + _ = model.evaluate(steps=10, verbose=0) + + def test_multi_io_workflow_with_tensors(self): + + num_classes = (2, 3) + num_samples = 10 + input_dim = 50 + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + + x1 = array_ops.ones((num_samples, input_dim)) + x2 = array_ops.ones((num_samples, input_dim)) + y1 = array_ops.zeros((num_samples, num_classes[0])) + y2 = array_ops.zeros((num_samples, num_classes[1])) + + model.fit([x1, x2], [y1, y2], epochs=2, steps_per_epoch=10, verbose=0) + _ = model.evaluate(steps=10, verbose=0) + + def test_multi_io_workflow_with_numpy_arrays_and_custom_placeholders(self): + + num_classes = (2, 3) + num_samples = 1000 + input_dim = 50 + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, + use_dp=True, + use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + x2_placeholder = array_ops.placeholder( + dtype='float32', shape=(None, input_dim)) + model._set_inputs([x1, x2_placeholder]) + + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + _ = model.evaluate([x1, x2], [y1, y2], verbose=0) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def test_attributes(self): + # layers, weights, trainable_weights, non_trainable_weights, inputs, outputs + + num_classes = (2, 3) + num_samples = 100 + input_dim = 50 + + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + self.assertEqual(model.name, 'test_model') + self.assertEqual(model.built, False) + self.assertEqual(len(model.weights), 0) + + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.train_on_batch([x1, x2], [y1, y2]) + + self.assertEqual(model.built, True) + self.assertEqual(len(model.layers), 4) + self.assertEqual(len(model.weights), 10) + self.assertEqual(len(model.trainable_weights), 8) + self.assertEqual(len(model.non_trainable_weights), 2) + self.assertEqual(len(model.inputs), 2) + self.assertEqual(len(model.outputs), 2) + + @test_util.run_in_graph_and_eager_modes() + def test_updates(self): + # test that updates get run during training + num_samples = 100 + input_dim = 50 + + class BNNet(keras.Model): + + def __init__(self): + super(BNNet, self).__init__() + self.bn = keras.layers.BatchNormalization(beta_initializer='ones', + gamma_initializer='ones') + + def call(self, inputs): + return self.bn(inputs) + + x = np.ones((num_samples, input_dim)) + y = np.ones((num_samples, input_dim)) + + with self.test_session(): + model = BNNet() + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + y_ref = model.predict(x) + + model.train_on_batch(x, y) + y_new = model.predict(x) + self.assertGreater(np.sum(np.abs(y_ref - y_new)), 0.1) + + @test_util.run_in_graph_and_eager_modes() + def test_training_and_inference_behavior(self): + # test that dropout is applied in training and not inference + + num_samples = 100 + input_dim = 50 + + class DPNet(keras.Model): + + def __init__(self): + super(DPNet, self).__init__() + self.dp = keras.layers.Dropout(0.5) + self.dense = keras.layers.Dense(1, + use_bias=False, + kernel_initializer='ones') + + def call(self, inputs): + x = self.dp(inputs) + return self.dense(x) + + with self.test_session(): + model = DPNet() + x = np.ones((num_samples, input_dim)) + y = model.predict(x) + self.assertEqual(np.sum(y), np.sum(x)) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + loss = model.train_on_batch(x, y) + self.assertGreater(loss, 0.1) + + @test_util.run_in_graph_and_eager_modes() + def test_training_methods(self): + # test fit, train_on_batch + # on different input types: list, dict + + num_classes = (2, 3) + num_samples = 100 + input_dim = 50 + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + model.fit({'input_1': x1, 'input_2': x2}, + {'output_1': y1, 'output_2': y2}, + epochs=2, batch_size=32) + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0, + validation_data=([x1, x2], [y1, y2])) + + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.train_on_batch([x1, x2], [y1, y2]) + model.train_on_batch({'input_1': x1, 'input_2': x2}, + {'output_1': y1, 'output_2': y2}) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def test_inference_methods(self): + # test predict, evaluate, test_on_batch, predict_on_batch + # on different input types: list, dict + num_classes = (2, 3) + num_samples = 100 + input_dim = 50 + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.evaluate([x1, x2], [y1, y2]) + model.test_on_batch([x1, x2], [y1, y2]) + + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.predict([x1, x2]) + + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.predict_on_batch([x1, x2]) + + @test_util.run_in_graph_and_eager_modes() + def test_trainable_mutation(self): + # test that you can change `trainable` on a model or layer, and that + # it freezes the model state during training + # TODO(fchollet): add test after we unify BN behavior in eager and symbolic. + pass + + @test_util.run_in_graph_and_eager_modes() + def test_saving(self): + if h5py is None: + return # Skip test if models cannot be saved. + + num_classes = (2, 3) + num_samples = 100 + input_dim = 50 + + x1 = np.ones((num_samples, input_dim)) + x2 = np.ones((num_samples, input_dim)) + y1 = np.zeros((num_samples, num_classes[0])) + y2 = np.zeros((num_samples, num_classes[1])) + + with self.test_session(): + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0) + y_ref_1, y_ref_2 = model.predict([x1, x2]) + + fd, fname = tempfile.mkstemp('.h5') + model.save_weights(fname) + + model = MultiIOTestModel(num_classes=num_classes, use_bn=True) + # need to build the model before loading weights + # (otherwise no weights to load) + model._set_inputs([x1, x2]) + model.load_weights(fname) + + y1, y2 = model.predict([x1, x2]) + self.assertAllClose(y_ref_1, y1, atol=1e-5) + self.assertAllClose(y_ref_2, y2, atol=1e-5) + os.close(fd) + os.remove(fname) + + @test_util.run_in_graph_and_eager_modes() + def test_summary(self): + + class ToString(object): + + def __init__(self): + self.contents = '' + + def __call__(self, msg): + self.contents += msg + '\n' + + # Single-io + model = SimpleTestModel(num_classes=4, use_bn=True, use_dp=True) + model._set_inputs(np.ones((3, 4))) # need to build model first + print_fn = ToString() + model.summary(print_fn=print_fn) + self.assertTrue('Trainable params: 356' in print_fn.contents) + + # Multi-io + model = MultiIOTestModel(num_classes=(5, 6), use_bn=True, use_dp=True) + model._set_inputs([np.ones((3, 4)), + np.ones((3, 4))]) # need to build model first + print_fn = ToString() + model.summary(print_fn=print_fn) + self.assertTrue('Trainable params: 587' in print_fn.contents) + + @test_util.run_in_graph_and_eager_modes() + def test_subclass_nested_in_subclass(self): + num_classes = 2 + num_samples = 100 + input_dim = 50 + + with self.test_session(): + model = NestedTestModel1(num_classes=num_classes) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) + + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) + + self.assertEqual(len(model.weights), 8 + len(model.test_net.weights)) + self.assertEqual(len(model.non_trainable_weights), + 2 + len(model.test_net.non_trainable_weights)) + self.assertEqual(len(model.trainable_weights), + 6 + len(model.test_net.trainable_weights)) + + @test_util.run_in_graph_and_eager_modes() + def test_graph_nested_in_subclass(self): + num_classes = 2 + num_samples = 100 + input_dim = 50 + + with self.test_session(): + model = NestedTestModel2(num_classes=num_classes) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) + + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) + + self.assertEqual(len(model.weights), 8 + len(model.test_net.weights)) + self.assertEqual(len(model.non_trainable_weights), + 2 + len(model.test_net.non_trainable_weights)) + self.assertEqual(len(model.trainable_weights), + 6 + len(model.test_net.trainable_weights)) + + @test_util.run_in_graph_and_eager_modes() + def test_subclass_nested_in_graph(self): + num_classes = 2 + num_samples = 100 + input_dim = 50 + + with self.test_session(): + model = get_nested_model_3(input_dim=input_dim, num_classes=num_classes) + model.compile(loss='mse', + optimizer=RMSPropOptimizer(learning_rate=0.001), + metrics=['acc']) + + x = np.ones((num_samples, input_dim)) + y = np.zeros((num_samples, num_classes)) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + _ = model.evaluate(x, y, verbose=0) + + self.assertEqual(len(model.weights), 16) + self.assertEqual( + len(model.non_trainable_weights), 4) + self.assertEqual(len(model.trainable_weights), 12) + + @test_util.run_in_graph_and_eager_modes() + def test_support_for_manual_training_arg(self): + # In most cases, the `training` argument is left unspecified, in which + # case it defaults to value corresponding to the Model method being used + # (fit -> True, predict -> False, etc). + # If the user writes their model `call` method to take + # an explicit `training` argument, we must check that the correct value + # is being passed to the model for each method call. + + class DPNet(keras.Model): + + def __init__(self): + super(DPNet, self).__init__() + self.dp = keras.layers.Dropout(0.5) + self.dense = keras.layers.Dense(1, + use_bias=False, + kernel_initializer='ones') + + def call(self, inputs, training=False): + x = self.dp(inputs, training=training) + return self.dense(x) + + with self.test_session(): + model = DPNet() + x = np.ones((10, 10)) + y = model.predict(x) + self.assertEqual(np.sum(y), np.sum(x)) + model.compile(loss='mse', optimizer=RMSPropOptimizer(learning_rate=0.001)) + loss = model.train_on_batch(x, y) + self.assertGreater(loss, 0.1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/python/keras/_impl/keras/models.py b/tensorflow/python/keras/_impl/keras/models.py index f5d44ef66916c82d131d307d4ca1ed91b377ccc2..9602e7ba39b290f33c7ca9d0d1b5b35838667531 100644 --- a/tensorflow/python/keras/_impl/keras/models.py +++ b/tensorflow/python/keras/_impl/keras/models.py @@ -13,1303 +13,30 @@ # limitations under the License. # ============================================================================== # pylint: disable=protected-access -"""Home of the Sequential model, and the `save_model`/`load_model` functions. +"""Code for model cloning, plus model-related API entries. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import copy -import json -import os - -import numpy as np - -from tensorflow.python.framework import ops from tensorflow.python.keras._impl.keras import backend as K -from tensorflow.python.keras._impl.keras import layers as layer_module -from tensorflow.python.keras._impl.keras import optimizers -from tensorflow.python.keras._impl.keras.engine import topology -from tensorflow.python.keras._impl.keras.engine.topology import Input -from tensorflow.python.keras._impl.keras.engine.topology import InputLayer -from tensorflow.python.keras._impl.keras.engine.topology import Layer -from tensorflow.python.keras._impl.keras.engine.topology import TFBaseLayer -from tensorflow.python.keras._impl.keras.engine.training import Model +from tensorflow.python.keras._impl.keras.engine import saving +from tensorflow.python.keras._impl.keras.engine import sequential +from tensorflow.python.keras._impl.keras.engine import training +from tensorflow.python.keras._impl.keras.engine.input_layer import Input +from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer +from tensorflow.python.keras._impl.keras.utils import generic_utils from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg -from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util.tf_export import tf_export - - -# pylint: disable=g-import-not-at-top -try: - import h5py -except ImportError: - h5py = None - -try: - import yaml -except ImportError: - yaml = None -# pylint: enable=g-import-not-at-top - - -@tf_export('keras.models.save_model') -def save_model(model, filepath, overwrite=True, include_optimizer=True): - """Save a model to a HDF5 file. - - The saved model contains: - - the model's configuration (topology) - - the model's weights - - the model's optimizer's state (if any) - - Thus the saved model can be reinstantiated in - the exact same state, without any of the code - used for model definition or training. - - Arguments: - model: Keras model instance to be saved. - filepath: String, path where to save the model. - overwrite: Whether we should overwrite any existing - model at the target location, or instead - ask the user with a manual prompt. - include_optimizer: If True, save optimizer's state together. - - Raises: - ImportError: if h5py is not available. - """ - - if h5py is None: - raise ImportError('`save_model` requires h5py.') - - def get_json_type(obj): - """Serialize any object to a JSON-serializable structure. - - Arguments: - obj: the object to serialize - - Returns: - JSON-serializable structure representing `obj`. - - Raises: - TypeError: if `obj` cannot be serialized. - """ - # if obj is a serializable Keras class instance - # e.g. optimizer, layer - if hasattr(obj, 'get_config'): - return {'class_name': obj.__class__.__name__, 'config': obj.get_config()} - - # if obj is any numpy type - if type(obj).__module__ == np.__name__: - if isinstance(obj, np.ndarray): - return {'type': type(obj), 'value': obj.tolist()} - else: - return obj.item() - - # misc functions (e.g. loss function) - if callable(obj): - return obj.__name__ - - # if obj is a python 'type' - if type(obj).__name__ == type.__name__: - return obj.__name__ - - raise TypeError('Not JSON Serializable:', obj) - - from tensorflow.python.keras._impl.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top - - # If file exists and should not be overwritten. - if not overwrite and os.path.isfile(filepath): - proceed = ask_to_proceed_with_overwrite(filepath) - if not proceed: - return - - with h5py.File(filepath, mode='w') as f: - f.attrs['keras_version'] = str(keras_version).encode('utf8') - f.attrs['backend'] = K.backend().encode('utf8') - f.attrs['model_config'] = json.dumps( - { - 'class_name': model.__class__.__name__, - 'config': model.get_config() - }, - default=get_json_type).encode('utf8') - - model_weights_group = f.create_group('model_weights') - model_layers = model.layers - topology.save_weights_to_hdf5_group(model_weights_group, model_layers) - - if include_optimizer and hasattr(model, 'optimizer'): - if isinstance(model.optimizer, optimizers.TFOptimizer): - logging.warning( - 'TensorFlow optimizers do not ' - 'make it possible to access ' - 'optimizer attributes or optimizer state ' - 'after instantiation. ' - 'As a result, we cannot save the optimizer ' - 'as part of the model save file.' - 'You will have to compile your model again after loading it. ' - 'Prefer using a Keras optimizer instead ' - '(see keras.io/optimizers).') - else: - f.attrs['training_config'] = json.dumps( - { - 'optimizer_config': { - 'class_name': model.optimizer.__class__.__name__, - 'config': model.optimizer.get_config() - }, - 'loss': model.loss, - 'metrics': model.metrics, - 'sample_weight_mode': model.sample_weight_mode, - 'loss_weights': model.loss_weights, - }, - default=get_json_type).encode('utf8') - - # Save optimizer weights. - symbolic_weights = getattr(model.optimizer, 'weights') - if symbolic_weights: - optimizer_weights_group = f.create_group('optimizer_weights') - weight_values = K.batch_get_value(symbolic_weights) - weight_names = [] - for w, val in zip(symbolic_weights, weight_values): - name = str(w.name) - weight_names.append(name.encode('utf8')) - optimizer_weights_group.attrs['weight_names'] = weight_names - for name, val in zip(weight_names, weight_values): - param_dset = optimizer_weights_group.create_dataset( - name, val.shape, dtype=val.dtype) - if not val.shape: - # scalar - param_dset[()] = val - else: - param_dset[:] = val - f.flush() - - -@tf_export('keras.models.load_model') -def load_model(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin - """Loads a model saved via `save_model`. - - Arguments: - filepath: String, path to the saved model. - custom_objects: Optional dictionary mapping names - (strings) to custom classes or functions to be - considered during deserialization. - compile: Boolean, whether to compile the model - after loading. - - Returns: - A Keras model instance. If an optimizer was found - as part of the saved model, the model is already - compiled. Otherwise, the model is uncompiled and - a warning will be displayed. When `compile` is set - to False, the compilation is omitted without any - warning. - - Raises: - ImportError: if h5py is not available. - ValueError: In case of an invalid savefile. - """ - if h5py is None: - raise ImportError('`load_model` requires h5py.') - - if not custom_objects: - custom_objects = {} - - def convert_custom_objects(obj): - """Handles custom object lookup. - - Arguments: - obj: object, dict, or list. - - Returns: - The same structure, where occurrences - of a custom object name have been replaced - with the custom object. - """ - if isinstance(obj, list): - deserialized = [] - for value in obj: - deserialized.append(convert_custom_objects(value)) - return deserialized - if isinstance(obj, dict): - deserialized = {} - for key, value in obj.items(): - deserialized[key] = convert_custom_objects(value) - return deserialized - if obj in custom_objects: - return custom_objects[obj] - return obj - - with h5py.File(filepath, mode='r') as f: - # instantiate model - model_config = f.attrs.get('model_config') - if model_config is None: - raise ValueError('No model found in config file.') - model_config = json.loads(model_config.decode('utf-8')) - model = model_from_config(model_config, custom_objects=custom_objects) - - # set weights - topology.load_weights_from_hdf5_group(f['model_weights'], model.layers) - - # Early return if compilation is not required. - if not compile: - return model - - # instantiate optimizer - training_config = f.attrs.get('training_config') - if training_config is None: - logging.warning('No training configuration found in save file: ' - 'the model was *not* compiled. Compile it manually.') - return model - training_config = json.loads(training_config.decode('utf-8')) - optimizer_config = training_config['optimizer_config'] - optimizer = optimizers.deserialize( - optimizer_config, custom_objects=custom_objects) - - # Recover loss functions and metrics. - loss = convert_custom_objects(training_config['loss']) - metrics = convert_custom_objects(training_config['metrics']) - sample_weight_mode = training_config['sample_weight_mode'] - loss_weights = training_config['loss_weights'] - - # Compile model. - model.compile( - optimizer=optimizer, - loss=loss, - metrics=metrics, - loss_weights=loss_weights, - sample_weight_mode=sample_weight_mode) - - # Set optimizer weights. - if 'optimizer_weights' in f: - # Build train function (to get weight updates). - if isinstance(model, Sequential): - model.model._make_train_function() - else: - model._make_train_function() - optimizer_weights_group = f['optimizer_weights'] - optimizer_weight_names = [ - n.decode('utf8') - for n in optimizer_weights_group.attrs['weight_names'] - ] - optimizer_weight_values = [ - optimizer_weights_group[n] for n in optimizer_weight_names - ] - try: - model.optimizer.set_weights(optimizer_weight_values) - except ValueError: - logging.warning('Error in loading the saved optimizer ' - 'state. As a result, your model is ' - 'starting with a freshly initialized ' - 'optimizer.') - return model - - -@tf_export('keras.models.model_from_config') -def model_from_config(config, custom_objects=None): - """Instantiates a Keras model from its config. - - Arguments: - config: Configuration dictionary. - custom_objects: Optional dictionary mapping names - (strings) to custom classes or functions to be - considered during deserialization. - - Returns: - A Keras model instance (uncompiled). - - Raises: - TypeError: if `config` is not a dictionary. - """ - if isinstance(config, list): - raise TypeError('`model_from_config` expects a dictionary, not a list. ' - 'Maybe you meant to use ' - '`Sequential.from_config(config)`?') - return layer_module.deserialize(config, custom_objects=custom_objects) - - -@tf_export('keras.models.model_from_yaml') -def model_from_yaml(yaml_string, custom_objects=None): - """Parses a yaml model configuration file and returns a model instance. - - Arguments: - yaml_string: YAML string encoding a model configuration. - custom_objects: Optional dictionary mapping names - (strings) to custom classes or functions to be - considered during deserialization. - - Returns: - A Keras model instance (uncompiled). - - Raises: - ImportError: if yaml module is not found. - """ - if yaml is None: - raise ImportError('Requires yaml module installed.') - config = yaml.load(yaml_string) - return layer_module.deserialize(config, custom_objects=custom_objects) - - -@tf_export('keras.models.model_from_json') -def model_from_json(json_string, custom_objects=None): - """Parses a JSON model configuration file and returns a model instance. - - Arguments: - json_string: JSON string encoding a model configuration. - custom_objects: Optional dictionary mapping names - (strings) to custom classes or functions to be - considered during deserialization. - - Returns: - A Keras model instance (uncompiled). - """ - config = json.loads(json_string) - return layer_module.deserialize(config, custom_objects=custom_objects) - - -@tf_export('keras.models.Sequential', 'keras.Sequential') -class Sequential(Model): - """Linear stack of layers. - - Arguments: - layers: list of layers to add to the model. - - # Note - The first layer passed to a Sequential model - should have a defined input shape. What that - means is that it should have received an `input_shape` - or `batch_input_shape` argument, - or for some type of layers (recurrent, Dense...) - an `input_dim` argument. - - Example: - - ```python - model = Sequential() - # first layer must have a defined input shape - model.add(Dense(32, input_dim=500)) - # afterwards, Keras does automatic shape inference - model.add(Dense(32)) - - # also possible (equivalent to the above): - model = Sequential() - model.add(Dense(32, input_shape=(500,))) - model.add(Dense(32)) - - # also possible (equivalent to the above): - model = Sequential() - # here the batch dimension is None, - # which means any batch size will be accepted by the model. - model.add(Dense(32, batch_input_shape=(None, 500))) - model.add(Dense(32)) - ``` - """ - - def __init__(self, layers=None, name=None): - self.layers = [] # Stack of layers. - self.model = None # Internal Model instance. - self.inputs = [] # List of input tensors - self.outputs = [] # List of length 1: the output tensor (unique). - self._trainable = True - self._initial_weights = None - self._input_layers = [] - - # Model attributes. - self._inbound_nodes = [] - self._outbound_nodes = [] - self.built = False - - # Set model name. - if not name: - prefix = 'sequential_' - name = prefix + str(K.get_uid(prefix)) - self._name = name - - # Used by Layer base class. - self._dtype = None - self._activity_regularizer = None - - # The following properties are not actually used by Keras; - # they exist for compatibility with TF's variable scoping mechanism. - self._updates = [] - self._losses = [] - self._scope = None - self._reuse = None - self._base_name = name - self._graph = ops.get_default_graph() - - # Add to the model any layers passed to the constructor. - if layers: - for layer in layers: - self.add(layer) - - def add(self, layer): - """Adds a layer instance on top of the layer stack. - - Arguments: - layer: layer instance. - - Raises: - TypeError: If `layer` is not a layer instance. - ValueError: In case the `layer` argument does not - know its input shape. - ValueError: In case the `layer` argument has - multiple output tensors, or is already connected - somewhere else (forbidden in `Sequential` models). - """ - if not isinstance(layer, (Layer, TFBaseLayer)): - raise TypeError('The added layer must be ' - 'an instance of class Layer. ' - 'Found: ' + str(layer)) - if not self.outputs: - # First layer in model: check that it is an input layer. - if not isinstance(layer, InputLayer): - # Create an input layer. - # First, we need to infer its expected input shape and dtype. - if isinstance(layer, (Model, Sequential)): - # We were passed a model as first layer. - # This requires a specific way to figure out the - # input shape and dtype. - if not layer.layers: - raise ValueError('Cannot add an empty model ' - 'to a `Sequential` model.') - # In case of nested models: recover the first layer - # of the deepest model to infer input shape and dtype. - first_layer = layer.layers[0] - while isinstance(first_layer, (Model, Sequential)): - first_layer = first_layer.layers[0] - batch_shape = first_layer._batch_input_shape - dtype = first_layer.dtype - else: - # We were passed a regular layer, and it should - # know about its input shape. Otherwise, that's an error. - if not hasattr(layer, '_batch_input_shape'): - raise ValueError('The first layer in a ' - 'Sequential model must ' - 'get an `input_shape` argument.') - batch_shape = layer._batch_input_shape - dtype = layer.dtype - # Instantiate the input layer. - x = Input( - batch_shape=batch_shape, dtype=dtype, name=layer.name + '_input') - # This will build the current layer - # and create the node connecting the current layer - # to the input layer we just created. - layer(x) - - if len(layer._inbound_nodes[-1].output_tensors) != 1: - raise ValueError('All layers in a Sequential model ' - 'should have a single output tensor. ' - 'For multi-output layers, ' - 'use the functional API.') - - self.outputs = [layer._inbound_nodes[-1].output_tensors[0]] - self.inputs = topology.get_source_inputs(self.outputs[0]) - - # We create an input node, which we will keep updated - # as we add more layers - topology.Node( - outbound_layer=self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=self.inputs, - output_tensors=self.outputs) - else: - output_tensor = layer(self.outputs[0]) - if isinstance(output_tensor, list): - raise TypeError('All layers in a Sequential model ' - 'should have a single output tensor. ' - 'For multi-output layers, ' - 'use the functional API.') - self.outputs = [output_tensor] - # update self._inbound_nodes - self._inbound_nodes[0].output_tensors = self.outputs - self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])] - - self.layers.append(layer) - self.built = False - - def pop(self): - """Removes the last layer in the model. - - Raises: - TypeError: if there are no layers in the model. - """ - if not self.layers: - raise TypeError('There are no layers in the model.') - - self.layers.pop() - if not self.layers: - self.outputs = [] - self._inbound_nodes = [] - self._outbound_nodes = [] - else: - self.layers[-1]._outbound_nodes = [] - self.outputs = [self.layers[-1].output] - # update self._inbound_nodes - self._inbound_nodes[0].output_tensors = self.outputs - self._inbound_nodes[0].output_shapes = [K.int_shape(self.outputs[0])] - self.built = False - - def get_layer(self, name=None, index=None): - """Retrieve a layer that is part of the model. - - Returns a layer based on either its name (unique) - or its index in the graph. Indices are based on - order of horizontal graph traversal (bottom-up). - - Arguments: - name: string, name of layer. - index: integer, index of layer. - - Returns: - A layer instance. - """ - if not self.built: - self.build() - return self.model.get_layer(name, index) - - def call(self, inputs, mask=None): - if not self.built: - self.build() - return self.model.call(inputs, mask) - - def build(self, input_shape=None): - if not self.inputs or not self.outputs: - raise TypeError('Sequential model cannot be built: model is empty.' - ' Add some layers first.') - # actually create the model - self.model = Model(self.inputs, self.outputs[0], name=self.name + '_model') - self.model.trainable = self.trainable - - # mirror model attributes - self.supports_masking = self.model.supports_masking - self._output_mask_cache = self.model._output_mask_cache - self._output_tensor_cache = self.model._output_tensor_cache - self._output_shape_cache = self.model._output_shape_cache - self._input_layers = self.model._input_layers - self._output_layers = self.model._output_layers - self._input_coordinates = self.model._input_coordinates - self._output_coordinates = self.model._output_coordinates - self._nodes_by_depth = self.model._nodes_by_depth - self._network_nodes = self.model._network_nodes - self.output_names = self.model.output_names - self.input_names = self.model.input_names - self._feed_input_names = self.model._feed_input_names - self._feed_inputs = self.model._feed_inputs - - # Make sure child model callbacks - # will call the parent Sequential model. - self.model.callback_model = self - - self.built = True - - @property - def uses_learning_phase(self): - if not self.built: - self.build() - return self.model.uses_learning_phase - - def _gather_list_attr(self, attr): - all_attrs = [] - for layer in self.layers: - all_attrs += getattr(layer, attr, []) - return all_attrs - - @property - def trainable(self): - return self._trainable - - @trainable.setter - def trainable(self, value): - if self.model: - self.model.trainable = value - self._trainable = value - - @property - def trainable_weights(self): - if not self.trainable: - return [] - return self._gather_list_attr('trainable_weights') - - @property - def non_trainable_weights(self): - weights = self._gather_list_attr('non_trainable_weights') - if not self.trainable: - trainable_weights = self._gather_list_attr('trainable_weights') - return trainable_weights + weights - return weights - - @property - def regularizers(self): - if not self.built: - self.build() - return self.model.regularizers - - def get_weights(self): - """Retrieves the weights of the model. - - Returns: - A flat list of Numpy arrays - (one array per model weight). - """ - if not self.built: - self.build() - return self.model.get_weights() - - def set_weights(self, weights): - """Sets the weights of the model. - - Arguments: - weights: Should be a list - of Numpy arrays with shapes and types matching - the output of `model.get_weights()`. - """ - if not self.built: - self.build() - self.model.set_weights(weights) - - def load_weights(self, filepath, by_name=False): - if h5py is None: - raise ImportError('`load_weights` requires h5py.') - f = h5py.File(filepath, mode='r') - if 'layer_names' not in f.attrs and 'model_weights' in f: - f = f['model_weights'] - layers = self.layers - if by_name: - topology.load_weights_from_hdf5_group_by_name(f, layers) - else: - topology.load_weights_from_hdf5_group(f, layers) - if hasattr(f, 'close'): - f.close() - - def save_weights(self, filepath, overwrite=True): - if h5py is None: - raise ImportError('`save_weights` requires h5py.') - # If file exists and should not be overwritten: - if not overwrite and os.path.isfile(filepath): - proceed = ask_to_proceed_with_overwrite(filepath) - if not proceed: - return - layers = self.layers - f = h5py.File(filepath, 'w') - topology.save_weights_to_hdf5_group(f, layers) - f.flush() - f.close() - - def compile(self, - optimizer, - loss, - metrics=None, - sample_weight_mode=None, - weighted_metrics=None, - target_tensors=None, - **kwargs): - """Configures the model for training. - - Arguments: - optimizer: String (name of optimizer) or optimizer object. - See [optimizers](/optimizers). - loss: String (name of objective function) or objective function. - See [losses](/losses). - If the model has multiple outputs, you can use a different loss - on each output by passing a dictionary or a list of losses. - The loss value that will be minimized by the model - will then be the sum of all individual losses. - metrics: List of metrics to be evaluated by the model - during training and testing. - Typically you will use `metrics=['accuracy']`. - To specify different metrics for different outputs of a - multi-output model, you could also pass a dictionary, - such as `metrics={'output_a': 'accuracy'}`. - sample_weight_mode: If you need to do timestep-wise - sample weighting (2D weights), set this to `"temporal"`. - `None` defaults to sample-wise weights (1D). - If the model has multiple outputs, you can use a different - `sample_weight_mode` on each output by passing a - dictionary or a list of modes. - weighted_metrics: list of metrics to be evaluated and weighted - by `sample_weight` or `class_weight` during training and testing. - target_tensors: By default, Keras will create a placeholder for the - model's target, which will be fed with the target data during - training. If instead you would like to use your own - target tensor (in turn, Keras will not expect external - Numpy data for these targets at training time), you - can specify them via the `target_tensors` argument. - It should be a single tensor - (for a single-output `Sequential` model). - **kwargs: These arguments are passed into `tf.Session.run`. - - Example: - ```python - model = Sequential() - model.add(Dense(32, input_shape=(500,))) - model.add(Dense(10, activation='softmax')) - model.compile(optimizer='rmsprop', - loss='categorical_crossentropy', - metrics=['accuracy']) - ``` - """ - # create the underlying model - self.build() - # call compile method of Model class - self.model.compile( - optimizer, - loss, - metrics=metrics, - sample_weight_mode=sample_weight_mode, - weighted_metrics=weighted_metrics, - target_tensors=target_tensors, - **kwargs) - self.optimizer = self.model.optimizer - self.loss = self.model.loss - self.metrics = self.model.metrics - self.loss_weights = self.model.loss_weights - self.sample_weight_mode = self.model.sample_weight_mode - self.weighted_metrics = self.model.weighted_metrics - self.targets = self.model.targets - self.metrics_tensors = self.model.metrics_tensors - self.metrics_names = self.model.metrics_names - self.sample_weights = self.model.sample_weights - self.total_loss = self.model.total_loss - - def fit(self, - x=None, - y=None, - batch_size=None, - epochs=1, - verbose=1, - callbacks=None, - validation_split=0., - validation_data=None, - shuffle=True, - class_weight=None, - sample_weight=None, - initial_epoch=0, - steps_per_epoch=None, - validation_steps=None, - **kwargs): - """Trains the model for a fixed number of epochs. - - Arguments: - x: Numpy array of training data. - If the input layer in the model is named, you can also pass a - dictionary mapping the input name to a Numpy array. - `x` can be `None` (default) if feeding from - TensorFlow data tensors. - y: Numpy array of target (label) data. - If the output layer in the model is named, you can also pass a - dictionary mapping the output name to a Numpy array. - `y` can be `None` (default) if feeding from - TensorFlow data tensors. - batch_size: Integer or `None`. - Number of samples per gradient update. - If unspecified, it will default to 32. - epochs: Integer. Number of epochs to train the model. - An epoch is an iteration over the entire `x` and `y` - data provided. - Note that in conjunction with `initial_epoch`, - `epochs` is to be understood as "final epoch". - The model is not trained for a number of iterations - given by `epochs`, but merely until the epoch - of index `epochs` is reached. - verbose: 0, 1, or 2. Verbosity mode. - 0 = silent, 1 = progress bar, 2 = one line per epoch. - callbacks: List of `keras.callbacks.Callback` instances. - List of callbacks to apply during training. - See [callbacks](/callbacks). - validation_split: Float between 0 and 1: - Fraction of the training data to be used as validation data. - The model will set apart this fraction of the training data, - will not train on it, and will evaluate - the loss and any model metrics - on this data at the end of each epoch. - The validation data is selected from the last samples - in the `x` and `y` data provided, before shuffling. - validation_data: tuple `(x_val, y_val)` or tuple - `(x_val, y_val, val_sample_weights)` on which to evaluate - the loss and any model metrics at the end of each epoch. - The model will not be trained on this data. - This will override `validation_split`. - shuffle: Boolean (whether to shuffle the training data - before each epoch) or str (for 'batch'). - 'batch' is a special option for dealing with the - limitations of HDF5 data; it shuffles in batch-sized chunks. - Has no effect when `steps_per_epoch` is not `None`. - class_weight: Optional dictionary mapping class indices (integers) - to a weight (float) value, used for weighting the loss function - (during training only). - This can be useful to tell the model to - "pay more attention" to samples from - an under-represented class. - sample_weight: Optional Numpy array of weights for - the training samples, used for weighting the loss function - (during training only). You can either pass a flat (1D) - Numpy array with the same length as the input samples - (1:1 mapping between weights and samples), - or in the case of temporal data, - you can pass a 2D array with shape - `(samples, sequence_length)`, - to apply a different weight to every timestep of every sample. - In this case you should make sure to specify - `sample_weight_mode="temporal"` in `compile()`. - initial_epoch: Epoch at which to start training - (useful for resuming a previous training run). - steps_per_epoch: Total number of steps (batches of samples) - before declaring one epoch finished and starting the - next epoch. When training with input tensors such as - TensorFlow data tensors, the default `None` is equal to - the number of unique samples in your dataset divided by - the batch size, or 1 if that cannot be determined. - validation_steps: Only relevant if `steps_per_epoch` - is specified. Total number of steps (batches of samples) - to validate before stopping. - **kwargs: Used for backwards compatibility support. - - Returns: - A `History` object. Its `History.history` attribute is - a record of training loss values and metrics values - at successive epochs, as well as validation loss values - and validation metrics values (if applicable). - - Raises: - RuntimeError: If the model was never compiled. - ValueError: In case of mismatch between the provided input data - and what the model expects. - """ - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.fit( - x, - y, - batch_size=batch_size, - epochs=epochs, - verbose=verbose, - callbacks=callbacks, - validation_split=validation_split, - validation_data=validation_data, - shuffle=shuffle, - class_weight=class_weight, - sample_weight=sample_weight, - initial_epoch=initial_epoch, - steps_per_epoch=steps_per_epoch, - validation_steps=validation_steps) - - def evaluate(self, x, y, batch_size=32, verbose=1, sample_weight=None): - """Computes the loss on some input data, batch by batch. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - y: labels, as a Numpy array. - batch_size: integer. Number of samples per gradient update. - verbose: verbosity mode, 0 or 1. - sample_weight: sample weights, as a Numpy array. - - Returns: - Scalar test loss (if the model has no metrics) - or list of scalars (if the model computes other metrics). - The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - - Raises: - RuntimeError: if the model was never compiled. - """ - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.evaluate( - x, - y, - batch_size=batch_size, - verbose=verbose, - sample_weight=sample_weight) - - def predict(self, x, batch_size=32, verbose=0): - """Generates output predictions for the input samples. - - The input samples are processed batch by batch. - - Arguments: - x: the input data, as a Numpy array. - batch_size: integer. - verbose: verbosity mode, 0 or 1. - - Returns: - A Numpy array of predictions. - """ - if not self.built: - self.build() - return self.model.predict(x, batch_size=batch_size, verbose=verbose) - - def predict_on_batch(self, x): - """Returns predictions for a single batch of samples. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - - Returns: - A Numpy array of predictions. - """ - if not self.built: - self.build() - return self.model.predict_on_batch(x) - - def train_on_batch(self, x, y, class_weight=None, sample_weight=None): - """Single gradient update over one batch of samples. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - y: labels, as a Numpy array. - class_weight: dictionary mapping classes to a weight value, - used for scaling the loss function (during training only). - sample_weight: sample weights, as a Numpy array. - - Returns: - Scalar training loss (if the model has no metrics) - or list of scalars (if the model computes other metrics). - The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - - Raises: - RuntimeError: if the model was never compiled. - """ - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.train_on_batch( - x, y, sample_weight=sample_weight, class_weight=class_weight) - - def test_on_batch(self, x, y, sample_weight=None): - """Evaluates the model over a single batch of samples. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - y: labels, as a Numpy array. - sample_weight: sample weights, as a Numpy array. - - Returns: - Scalar test loss (if the model has no metrics) - or list of scalars (if the model computes other metrics). - The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - - Raises: - RuntimeError: if the model was never compiled. - """ - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.test_on_batch(x, y, sample_weight=sample_weight) - - def predict_proba(self, x, batch_size=32, verbose=0): - """Generates class probability predictions for the input samples. - - The input samples are processed batch by batch. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - batch_size: integer. - verbose: verbosity mode, 0 or 1. - - Returns: - A Numpy array of probability predictions. - """ - preds = self.predict(x, batch_size, verbose) - if preds.min() < 0. or preds.max() > 1.: - logging.warning('Network returning invalid probability values. ' - 'The last layer might not normalize predictions ' - 'into probabilities ' - '(like softmax or sigmoid would).') - return preds - - def predict_classes(self, x, batch_size=32, verbose=0): - """Generate class predictions for the input samples. - - The input samples are processed batch by batch. - - Arguments: - x: input data, as a Numpy array or list of Numpy arrays - (if the model has multiple inputs). - batch_size: integer. - verbose: verbosity mode, 0 or 1. - - Returns: - A numpy array of class predictions. - """ - proba = self.predict(x, batch_size=batch_size, verbose=verbose) - if proba.shape[-1] > 1: - return proba.argmax(axis=-1) - else: - return (proba > 0.5).astype('int32') - - def fit_generator(self, - generator, - steps_per_epoch=None, - epochs=1, - verbose=1, - callbacks=None, - validation_data=None, - validation_steps=None, - class_weight=None, - max_queue_size=10, - workers=1, - use_multiprocessing=False, - shuffle=True, - initial_epoch=0, - **kwargs): - """Fits the model on data generated batch-by-batch by a Python generator. - - The generator is run in parallel to the model, for efficiency. - For instance, this allows you to do real-time data augmentation - on images on CPU in parallel to training your model on GPU. - - Arguments: - generator: A generator. - The output of the generator must be either - - a tuple (inputs, targets) - - a tuple (inputs, targets, sample_weights). - All arrays should contain the same number of samples. - The generator is expected to loop over its data - indefinitely. An epoch finishes when `steps_per_epoch` - batches have been seen by the model. - steps_per_epoch: Total number of steps (batches of samples) - to yield from `generator` before declaring one epoch - finished and starting the next epoch. It should typically - be equal to the number of samples of your dataset - divided by the batch size. - Optional for `Sequence`: if unspecified, will use - the `len(generator)` as a number of steps. - epochs: Integer, total number of iterations on the data. - Note that in conjunction with initial_epoch, the parameter - epochs is to be understood as "final epoch". The model is - not trained for n steps given by epochs, but until the - epoch epochs is reached. - verbose: Verbosity mode, 0, 1, or 2. - callbacks: List of callbacks to be called during training. - validation_data: This can be either - - A generator for the validation data - - A tuple (inputs, targets) - - A tuple (inputs, targets, sample_weights). - validation_steps: Only relevant if `validation_data` - is a generator. - Number of steps to yield from validation generator - at the end of every epoch. It should typically - be equal to the number of samples of your - validation dataset divided by the batch size. - Optional for `Sequence`: if unspecified, will use - the `len(validation_data)` as a number of steps. - class_weight: Dictionary mapping class indices to a weight - for the class. - max_queue_size: Maximum size for the generator queue - workers: Maximum number of processes to spin up - use_multiprocessing: If True, use process based threading. - Note that because - this implementation relies on multiprocessing, - you should not pass - non picklable arguments to the generator - as they can't be passed - easily to children processes. - shuffle: Whether to shuffle the order of the batches at - the beginning of each epoch. Only used with instances - of `Sequence` (keras.utils.Sequence). - initial_epoch: Epoch at which to start training - (useful for resuming a previous training run) - **kwargs: support for legacy arguments. - - Returns: - A `History` object. - - Raises: - RuntimeError: if the model was never compiled. - ValueError: In case the generator yields - data in an invalid format. - - Example: - - ```python - def generate_arrays_from_file(path): - while 1: - f = open(path) - for line in f: - # create Numpy arrays of input data - # and labels, from each line in the file - x, y = process_line(line) - yield (x, y) - f.close() - - model.fit_generator(generate_arrays_from_file('/my_file.txt'), - steps_per_epoch=1000, epochs=10) - ``` - """ - # Legacy support - if 'max_q_size' in kwargs: - max_queue_size = kwargs.pop('max_q_size') - logging.warning('The argument `max_q_size` has been renamed ' - '`max_queue_size`. Update your method calls accordingly.') - if 'pickle_safe' in kwargs: - use_multiprocessing = kwargs.pop('pickle_safe') - logging.warning('The argument `pickle_safe` has been renamed ' - '`use_multiprocessing`. ' - 'Update your method calls accordingly.') - if kwargs: - raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) - - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.fit_generator( - generator, - steps_per_epoch, - epochs, - verbose=verbose, - callbacks=callbacks, - validation_data=validation_data, - validation_steps=validation_steps, - class_weight=class_weight, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing, - shuffle=shuffle, - initial_epoch=initial_epoch) - - def evaluate_generator(self, - generator, - steps=None, - max_queue_size=10, - workers=1, - use_multiprocessing=False, - **kwargs): - """Evaluates the model on a data generator. - - The generator should return the same kind of data - as accepted by `test_on_batch`. - - Arguments: - generator: Generator yielding tuples (inputs, targets) - or (inputs, targets, sample_weights) - steps: Total number of steps (batches of samples) - to yield from `generator` before stopping. - Optional for `Sequence`: if unspecified, will use - the `len(generator)` as a number of steps. - max_queue_size: maximum size for the generator queue - workers: maximum number of processes to spin up - use_multiprocessing: if True, use process based threading. - Note that because this implementation - relies on multiprocessing, you should not pass - non picklable arguments to the generator - as they can't be passed easily to children processes. - **kwargs: support for legacy arguments. - - Returns: - Scalar test loss (if the model has no metrics) - or list of scalars (if the model computes other metrics). - The attribute `model.metrics_names` will give you - the display labels for the scalar outputs. - - Raises: - RuntimeError: if the model was never compiled. - ValueError: In case the generator yields - data in an invalid format. - """ - # Legacy support - if 'max_q_size' in kwargs: - max_queue_size = kwargs.pop('max_q_size') - logging.warning('The argument `max_q_size` has been renamed ' - '`max_queue_size`. Update your method calls accordingly.') - if 'pickle_safe' in kwargs: - use_multiprocessing = kwargs.pop('pickle_safe') - logging.warning('The argument `pickle_safe` has been renamed ' - '`use_multiprocessing`. ' - 'Update your method calls accordingly.') - if kwargs: - raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) - - if not self.built: - raise RuntimeError('The model needs to be compiled before being used.') - return self.model.evaluate_generator( - generator, - steps, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing) - - def predict_generator(self, - generator, - steps=None, - max_queue_size=10, - workers=1, - use_multiprocessing=False, - verbose=0, - **kwargs): - """Generates predictions for the input samples from a data generator. - - The generator should return the same kind of data as accepted by - `predict_on_batch`. - - Arguments: - generator: generator yielding batches of input samples. - steps: Total number of steps (batches of samples) - to yield from `generator` before stopping. - Optional for `Sequence`: if unspecified, will use - the `len(generator)` as a number of steps. - max_queue_size: maximum size for the generator queue - workers: maximum number of processes to spin up - use_multiprocessing: if True, use process based threading. - Note that because this implementation - relies on multiprocessing, you should not pass - non picklable arguments to the generator - as they can't be passed easily to children processes. - verbose: verbosity mode, 0 or 1. - **kwargs: support for legacy arguments. - - Returns: - A Numpy array of predictions. - - Raises: - ValueError: In case the generator yields - data in an invalid format. - """ - # Legacy support - if 'max_q_size' in kwargs: - max_queue_size = kwargs.pop('max_q_size') - logging.warning('The argument `max_q_size` has been renamed ' - '`max_queue_size`. Update your method calls accordingly.') - if 'pickle_safe' in kwargs: - use_multiprocessing = kwargs.pop('pickle_safe') - logging.warning('The argument `pickle_safe` has been renamed ' - '`use_multiprocessing`. ' - 'Update your method calls accordingly.') - if kwargs: - raise ValueError('Unrecognized keyword arguments: ' + str(kwargs)) - - if not self.built: - self.build() - return self.model.predict_generator( - generator, - steps, - max_queue_size=max_queue_size, - workers=workers, - use_multiprocessing=use_multiprocessing, - verbose=verbose) - def get_config(self): - config = [] - for layer in self.layers: - config.append({ - 'class_name': layer.__class__.__name__, - 'config': layer.get_config() - }) - return copy.deepcopy(config) - @classmethod - def from_config(cls, config, custom_objects=None): - model = cls() - for conf in config: - layer = layer_module.deserialize(conf, custom_objects=custom_objects) - model.add(layer) - return model +# API entries importable from `keras.models`: +Model = training.Model # pylint: disable=invalid-name +Sequential = sequential.Sequential # pylint: disable=invalid-name +save_model = saving.save_model +load_model = saving.load_model +model_from_config = saving.model_from_config +model_from_yaml = saving.model_from_yaml +model_from_json = saving.model_from_json def _clone_functional_model(model, input_tensors=None): @@ -1363,7 +90,7 @@ def _clone_functional_model(model, input_tensors=None): else: # Make sure that all input tensors come from a Keras layer. # If tensor comes from an input layer: cache the input layer. - input_tensors = topology._to_list(input_tensors) + input_tensors = generic_utils.to_list(input_tensors) input_tensors_ = [] for i, x in enumerate(input_tensors): if not K.is_keras_tensor(x): @@ -1400,7 +127,7 @@ def _clone_functional_model(model, input_tensors=None): # Reuse previously cloned layer. layer = layer_map[layer] # Don't call InputLayer multiple times. - if isinstance(layer, topology.InputLayer): + if isinstance(layer, InputLayer): continue # Gather inputs to call the new layer. @@ -1425,8 +152,9 @@ def _clone_functional_model(model, input_tensors=None): if has_arg(layer.call, 'mask'): if 'mask' not in kwargs: kwargs['mask'] = computed_mask - output_tensors = topology._to_list(layer(computed_tensor, **kwargs)) - output_masks = topology._to_list( + output_tensors = generic_utils.to_list(layer(computed_tensor, + **kwargs)) + output_masks = generic_utils.to_list( layer.compute_mask(computed_tensor, computed_mask)) computed_tensors = [computed_tensor] computed_masks = [computed_mask] @@ -1436,8 +164,9 @@ def _clone_functional_model(model, input_tensors=None): if has_arg(layer.call, 'mask'): if 'mask' not in kwargs: kwargs['mask'] = computed_masks - output_tensors = topology._to_list(layer(computed_tensors, **kwargs)) - output_masks = topology._to_list( + output_tensors = generic_utils.to_list(layer(computed_tensors, + **kwargs)) + output_masks = generic_utils.to_list( layer.compute_mask(computed_tensors, computed_masks)) # Update tensor_map. for x, y, mask in zip(reference_output_tensors, output_tensors, @@ -1487,14 +216,14 @@ def _clone_sequential_model(model, input_tensors=None): if input_tensors is None: return Sequential(layers=layers, name=model.name) else: - if len(topology._to_list(input_tensors)) != 1: + if len(generic_utils.to_list(input_tensors)) != 1: raise ValueError('To clone a `Sequential` model, we expect ' ' at most one tensor ' 'as part of `input_tensors`.') - x = topology._to_list(input_tensors)[0] + x = generic_utils.to_list(input_tensors)[0] if K.is_keras_tensor(x): origin_layer = x._keras_history[0] - if isinstance(origin_layer, topology.InputLayer): + if isinstance(origin_layer, InputLayer): return Sequential(layers=[origin_layer] + layers, name=model.name) else: raise ValueError('Cannot clone a `Sequential` model on top ' diff --git a/tensorflow/python/keras/_impl/keras/models_test.py b/tensorflow/python/keras/_impl/keras/models_test.py index 04017e4b28b27e52f88a7746fc44510c29edffce..5978ddd987c63b9d87a31be6837172f08512ef73 100644 --- a/tensorflow/python/keras/_impl/keras/models_test.py +++ b/tensorflow/python/keras/_impl/keras/models_test.py @@ -12,362 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for training routines.""" +"""Tests for `models.py` (model cloning, mainly).""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os -import shutil -import tempfile - import numpy as np from tensorflow.python.keras._impl import keras from tensorflow.python.platform import test -from tensorflow.python.training import training as training_module - -try: - import h5py # pylint:disable=g-import-not-at-top -except ImportError: - h5py = None - - -class TestModelSaving(test.TestCase): - - def test_sequential_model_saving(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.RepeatVector(3)) - model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) - model.compile(loss=keras.losses.MSE, - optimizer=keras.optimizers.RMSprop(lr=0.0001), - metrics=[keras.metrics.categorical_accuracy], - sample_weight_mode='temporal') - x = np.random.random((1, 3)) - y = np.random.random((1, 3, 3)) - model.train_on_batch(x, y) - - out = model.predict(x) - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - - new_model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - out2 = new_model.predict(x) - self.assertAllClose(out, out2, atol=1e-05) - - # test that new updates are the same with both models - x = np.random.random((1, 3)) - y = np.random.random((1, 3, 3)) - model.train_on_batch(x, y) - new_model.train_on_batch(x, y) - out = model.predict(x) - out2 = new_model.predict(x) - self.assertAllClose(out, out2, atol=1e-05) - - def test_sequential_model_saving_2(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - # test with custom optimizer, loss - - class CustomOp(keras.optimizers.RMSprop): - pass - - def custom_loss(y_true, y_pred): - return keras.losses.mse(y_true, y_pred) - - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.Dense(3)) - model.compile(loss=custom_loss, optimizer=CustomOp(), metrics=['acc']) - - x = np.random.random((1, 3)) - y = np.random.random((1, 3)) - model.train_on_batch(x, y) - - out = model.predict(x) - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - - model = keras.models.load_model( - fname, - custom_objects={'CustomOp': CustomOp, - 'custom_loss': custom_loss}) - os.close(fd) - os.remove(fname) - - out2 = model.predict(x) - self.assertAllClose(out, out2, atol=1e-05) - - def test_functional_model_saving(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - inputs = keras.layers.Input(shape=(3,)) - x = keras.layers.Dense(2)(inputs) - output = keras.layers.Dense(3)(x) - - model = keras.models.Model(inputs, output) - model.compile(loss=keras.losses.MSE, - optimizer=keras.optimizers.RMSprop(lr=0.0001), - metrics=[keras.metrics.categorical_accuracy]) - x = np.random.random((1, 3)) - y = np.random.random((1, 3)) - model.train_on_batch(x, y) - - out = model.predict(x) - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - out2 = model.predict(x) - self.assertAllClose(out, out2, atol=1e-05) - - def test_saving_without_compilation(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.Dense(3)) - model.compile(loss='mse', optimizer='sgd', metrics=['acc']) - - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - def test_saving_with_tf_optimizer(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.Dense(3)) - model.compile(loss='mse', - optimizer=training_module.AdadeltaOptimizer(0.1), - metrics=['acc']) - - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - def test_saving_right_after_compilation(self): - if h5py is None: - return # Skip test if models cannot be saved. - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(2, input_shape=(3,))) - model.add(keras.layers.Dense(3)) - model.compile(loss='mse', optimizer='sgd', metrics=['acc']) - model.model._make_train_function() - - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - def test_saving_lambda_numpy_array_arguments(self): - if h5py is None: - return # Skip test if models cannot be saved. - - mean = np.random.random((4, 2, 3)) - std = np.abs(np.random.random((4, 2, 3))) + 1e-5 - inputs = keras.layers.Input(shape=(4, 2, 3)) - output = keras.layers.Lambda(lambda image, mu, std: (image - mu) / std, - arguments={'mu': mean, 'std': std})(inputs) - model = keras.models.Model(inputs, output) - model.compile(loss='mse', optimizer='sgd', metrics=['acc']) - - fd, fname = tempfile.mkstemp('.h5') - keras.models.save_model(model, fname) - - model = keras.models.load_model(fname) - os.close(fd) - os.remove(fname) - - self.assertAllClose(mean, model.layers[1].arguments['mu']) - self.assertAllClose(std, model.layers[1].arguments['std']) - - -class TestSequential(test.TestCase): - """Most Sequential model API tests are covered in `training_test.py`. - """ - - def test_basic_methods(self): - model = keras.models.Sequential() - model.add(keras.layers.Dense(1, input_dim=2)) - model.add(keras.layers.Dropout(0.3, name='dp')) - model.add(keras.layers.Dense(2, kernel_regularizer='l2', - kernel_constraint='max_norm')) - model.build() - self.assertEqual(model.state_updates, model.model.state_updates) - self.assertEqual(model.get_layer(name='dp').name, 'dp') - - def test_sequential_pop(self): - num_hidden = 5 - input_dim = 3 - batch_size = 5 - num_classes = 2 - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) - model.add(keras.layers.Dense(num_classes)) - model.compile(loss='mse', optimizer='sgd') - x = np.random.random((batch_size, input_dim)) - y = np.random.random((batch_size, num_classes)) - model.fit(x, y, epochs=1) - model.pop() - self.assertEqual(len(model.layers), 1) - self.assertEqual(model.output_shape, (None, num_hidden)) - model.compile(loss='mse', optimizer='sgd') - y = np.random.random((batch_size, num_hidden)) - model.fit(x, y, epochs=1) - - # Test popping single-layer model - model = keras.models.Sequential() - model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) - model.pop() - self.assertEqual(len(model.layers), 0) - self.assertEqual(len(model.outputs), 0) - - # Invalid use case - model = keras.models.Sequential() - with self.assertRaises(TypeError): - model.pop() - - def test_sequential_weight_loading(self): - if h5py is None: - return - - temp_dir = self.get_temp_dir() - self.addCleanup(shutil.rmtree, temp_dir) - h5_path = os.path.join(temp_dir, 'test.h5') - - num_hidden = 5 - input_dim = 3 - batch_size = 5 - num_classes = 2 - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) - model.add(keras.layers.Dense(num_classes)) - - x = np.random.random((batch_size, input_dim)) - ref_y = model.predict(x) - - model.save_weights(h5_path) - - model = keras.models.Sequential() - model.add(keras.layers.Dense(num_hidden, input_dim=input_dim)) - model.add(keras.layers.Dense(num_classes)) - model.load_weights(h5_path) - y = model.predict(x) - - self.assertAllClose(y, ref_y) - - def test_invalid_use_cases(self): - with self.test_session(): - # Added objects must be layer instances - with self.assertRaises(TypeError): - model = keras.models.Sequential() - model.add(None) - - # Added layers must have an inputs shape - with self.assertRaises(ValueError): - model = keras.models.Sequential() - model.add(keras.layers.Dense(1)) - - # Added layers cannot have multiple outputs - class MyLayer(keras.layers.Layer): - - def call(self, inputs): - return [3 * inputs, 2 * inputs] - - def compute_output_shape(self, input_shape): - return [input_shape, input_shape] - - with self.assertRaises(ValueError): - model = keras.models.Sequential() - model.add(MyLayer(input_shape=(3,))) - with self.assertRaises(TypeError): - model = keras.models.Sequential() - model.add(keras.layers.Dense(1, input_dim=1)) - model.add(MyLayer()) - - # Building empty model - model = keras.models.Sequential() - with self.assertRaises(TypeError): - model.build() - - def test_nested_sequential_trainability(self): - input_dim = 20 - num_units = 10 - num_classes = 2 - - inner_model = keras.models.Sequential() - inner_model.add(keras.layers.Dense(num_units, input_shape=(input_dim,))) - - model = keras.models.Sequential() - model.add(inner_model) - model.add(keras.layers.Dense(num_classes)) - - self.assertEqual(len(model.trainable_weights), 4) - inner_model.trainable = False - self.assertEqual(len(model.trainable_weights), 2) - inner_model.trainable = True - self.assertEqual(len(model.trainable_weights), 4) - - def test_sequential_update_disabling(self): - val_a = np.random.random((10, 4)) - val_out = np.random.random((10, 4)) - - with self.test_session(): - model = keras.models.Sequential() - model.add(keras.layers.BatchNormalization(input_shape=(4,))) - - model.trainable = False - assert not model.updates - - model.compile('sgd', 'mse') - assert not model.updates - assert not model.model.updates - - x1 = model.predict(val_a) - model.train_on_batch(val_a, val_out) - x2 = model.predict(val_a) - self.assertAllClose(x1, x2, atol=1e-7) - - model.trainable = True - model.compile('sgd', 'mse') - assert model.updates - assert model.model.updates - - model.train_on_batch(val_a, val_out) - x2 = model.predict(val_a) - assert np.abs(np.sum(x1 - x2)) > 1e-5 class TestModelCloning(test.TestCase): diff --git a/tensorflow/python/keras/_impl/keras/optimizers.py b/tensorflow/python/keras/_impl/keras/optimizers.py index 76a97156ed7d9ca89b0d94f31bed3a23eca9609d..6520128c5b65451aef20ec9626245fba5ef29927 100644 --- a/tensorflow/python/keras/_impl/keras/optimizers.py +++ b/tensorflow/python/keras/_impl/keras/optimizers.py @@ -704,8 +704,10 @@ class TFOptimizer(Optimizer): return self.optimizer.compute_gradients(loss, params) def get_updates(self, loss, params): - grads = self.optimizer.compute_gradients(loss, params) self.updates = [K.update_add(self.iterations, 1)] + if not params: + return self.updates + grads = self.optimizer.compute_gradients(loss, params) opt_update = self.optimizer.apply_gradients( grads, global_step=self.iterations) self.updates.append(opt_update) diff --git a/tensorflow/python/keras/_impl/keras/testing_utils.py b/tensorflow/python/keras/_impl/keras/testing_utils.py index b889e311b37d48732641205a90ca83af34ea4489..60799ee1e038b4466351248bb5de7c8fc0de02a2 100644 --- a/tensorflow/python/keras/_impl/keras/testing_utils.py +++ b/tensorflow/python/keras/_impl/keras/testing_utils.py @@ -22,6 +22,7 @@ import numpy as np from tensorflow.python.framework import tensor_shape from tensorflow.python.keras._impl import keras +from tensorflow.python.training.rmsprop import RMSPropOptimizer from tensorflow.python.util import tf_inspect @@ -105,8 +106,14 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, # test in functional API x = keras.layers.Input(shape=input_shape[1:], dtype=input_dtype) y = layer(x) - assert keras.backend.dtype(y) == expected_output_dtype - + if keras.backend.dtype(y) != expected_output_dtype: + raise AssertionError('When testing layer %s, for input %s, found output ' + 'dtype=%s but expected to find %s.\nFull kwargs: %s' % + (layer_cls.__name__, + x, + keras.backend.dtype(y), + expected_output_dtype, + kwargs)) # check shape inference model = keras.models.Model(x, y) expected_output_shape = tuple( @@ -117,7 +124,15 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, for expected_dim, actual_dim in zip(expected_output_shape, actual_output_shape): if expected_dim is not None: - assert expected_dim == actual_dim + if expected_dim != actual_dim: + raise AssertionError( + 'When testing layer %s, for input %s, found output_shape=' + '%s but expected to find %s.\nFull kwargs: %s' % + (layer_cls.__name__, + x, + actual_output_shape, + expected_output_shape, + kwargs)) if expected_output is not None: np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3) @@ -131,7 +146,7 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, np.testing.assert_allclose(output, actual_output, rtol=1e-3) # test training mode (e.g. useful for dropout tests) - model.compile('rmsprop', 'mse') + model.compile(RMSPropOptimizer(0.01), 'mse') model.train_on_batch(input_data, actual_output) # test as first layer in Sequential API @@ -146,7 +161,15 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, for expected_dim, actual_dim in zip(expected_output_shape, actual_output_shape): if expected_dim is not None: - assert expected_dim == actual_dim + if expected_dim != actual_dim: + raise AssertionError( + 'When testing layer %s, for input %s, found output_shape=' + '%s but expected to find %s.\nFull kwargs: %s' % + (layer_cls.__name__, + x, + actual_output_shape, + expected_output_shape, + kwargs)) if expected_output is not None: np.testing.assert_allclose(actual_output, expected_output, rtol=1e-3) @@ -159,9 +182,5 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None, output = recovered_model.predict(input_data) np.testing.assert_allclose(output, actual_output, rtol=1e-3) - # test training mode (e.g. useful for dropout tests) - model.compile('rmsprop', 'mse') - model.train_on_batch(input_data, actual_output) - # for further checks in the caller function return actual_output diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py index adbe6c3288a3eabb858e78267577ddff6d798972..5196bf17400c33d876daa430a9d3d5b4f4b491a1 100644 --- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py @@ -291,55 +291,73 @@ class Progbar(object): Arguments: target: Total number of steps expected, None if unknown. + width: Progress bar width on screen. + verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose) + stateful_metrics: Iterable of string names of metrics that + should *not* be averaged over time. Metrics in this list + will be displayed as-is. All others will be averaged + by the progbar before display. interval: Minimum visual progress update interval (in seconds). """ - def __init__(self, target, width=30, verbose=1, interval=0.05): - self.width = width - if target is None: - target = -1 + def __init__(self, target, width=30, verbose=1, interval=0.05, + stateful_metrics=None): self.target = target - self.sum_values = {} - self.unique_values = [] - self.start = time.time() - self.last_update = 0 - self.interval = interval - self.total_width = 0 - self.seen_so_far = 0 + self.width = width self.verbose = verbose + self.interval = interval + if stateful_metrics: + self.stateful_metrics = set(stateful_metrics) + else: + self.stateful_metrics = set() + self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and sys.stdout.isatty()) or - 'ipykernel' in sys.modules) - - def update(self, current, values=None, force=False): + 'ipykernel' in sys.modules or + 'posix' in sys.modules) + self._total_width = 0 + self._seen_so_far = 0 + # We use a dict + list to avoid garbage collection + # issues found in OrderedDict + self._values = {} + self._values_order = [] + self._start = time.time() + self._last_update = 0 + + def update(self, current, values=None): """Updates the progress bar. Arguments: current: Index of current step. - values: List of tuples (name, value_for_last_step). - The progress bar will display averages for these values. - force: Whether to force visual progress update. + values: List of tuples: + `(name, value_for_last_step)`. + If `name` is in `stateful_metrics`, + `value_for_last_step` will be displayed as-is. + Else, an average of the metric over time will be displayed. """ values = values or [] for k, v in values: - if k not in self.sum_values: - self.sum_values[k] = [ - v * (current - self.seen_so_far), current - self.seen_so_far - ] - self.unique_values.append(k) + if k not in self._values_order: + self._values_order.append(k) + if k not in self.stateful_metrics: + if k not in self._values: + self._values[k] = [v * (current - self._seen_so_far), + current - self._seen_so_far] + else: + self._values[k][0] += v * (current - self._seen_so_far) + self._values[k][1] += (current - self._seen_so_far) else: - self.sum_values[k][0] += v * (current - self.seen_so_far) - self.sum_values[k][1] += (current - self.seen_so_far) - self.seen_so_far = current + self._values[k] = v + self._seen_so_far = current now = time.time() - info = ' - %.0fs' % (now - self.start) + info = ' - %.0fs' % (now - self._start) if self.verbose == 1: - if (not force and (now - self.last_update) < self.interval and - current < self.target): + if (now - self._last_update < self.interval and + self.target is not None and current < self.target): return - prev_total_width = self.total_width + prev_total_width = self._total_width if self._dynamic_display: sys.stdout.write('\b' * prev_total_width) sys.stdout.write('\r') @@ -360,22 +378,21 @@ class Progbar(object): bar += '=' bar += ('.' * (self.width - prog_width)) bar += ']' - sys.stdout.write(bar) - self.total_width = len(bar) else: bar = '%7d/Unknown' % current - self.total_width = len(bar) + self._total_width = len(bar) sys.stdout.write(bar) if current: - time_per_unit = (now - self.start) / current + time_per_unit = (now - self._start) / current else: time_per_unit = 0 if self.target is not None and current < self.target: eta = time_per_unit * (self.target - current) if eta > 3600: - eta_format = '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60, + eta_format = '%d:%02d:%02d' % (eta // 3600, + (eta % 3600) // 60, eta % 60) elif eta > 60: eta_format = '%d:%02d' % (eta // 60, eta % 60) @@ -391,35 +408,32 @@ class Progbar(object): else: info += ' %.0fus/step' % (time_per_unit * 1e6) - for k in self.unique_values: + for k in self._values_order: info += ' - %s:' % k - if isinstance(self.sum_values[k], list): - avg = np.mean(self.sum_values[k][0] / max(1, self.sum_values[k][1])) + if isinstance(self._values[k], list): + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) if abs(avg) > 1e-3: info += ' %.4f' % avg else: info += ' %.4e' % avg else: - info += ' %s' % self.sum_values[k] + info += ' %s' % self._values[k] + + self._total_width += len(info) + if prev_total_width > self._total_width: + info += (' ' * (prev_total_width - self._total_width)) - self.total_width += len(info) - if prev_total_width > self.total_width: - info += (' ' * (prev_total_width - self.total_width)) if self.target is not None and current >= self.target: info += '\n' sys.stdout.write(info) sys.stdout.flush() - if current >= self.target: - sys.stdout.write('\n') - elif self.verbose == 2: if self.target is None or current >= self.target: - for k in self.unique_values: + for k in self._values_order: info += ' - %s:' % k - avg = np.mean( - self.sum_values[k][0] / max(1, self.sum_values[k][1])) + avg = np.mean(self._values[k][0] / max(1, self._values[k][1])) if avg > 1e-3: info += ' %.4f' % avg else: @@ -429,7 +443,86 @@ class Progbar(object): sys.stdout.write(info) sys.stdout.flush() - self.last_update = now + self._last_update = now def add(self, n, values=None): - self.update(self.seen_so_far + n, values) + self.update(self._seen_so_far + n, values) + + +def make_batches(size, batch_size): + """Returns a list of batch indices (tuples of indices). + + Arguments: + size: Integer, total size of the data to slice into batches. + batch_size: Integer, batch size. + + Returns: + A list of tuples of array indices. + """ + num_batches = int(np.ceil(size / float(batch_size))) + return [(i * batch_size, min(size, (i + 1) * batch_size)) + for i in range(0, num_batches)] + + +def slice_arrays(arrays, start=None, stop=None): + """Slice an array or list of arrays. + + This takes an array-like, or a list of + array-likes, and outputs: + - arrays[start:stop] if `arrays` is an array-like + - [x[start:stop] for x in arrays] if `arrays` is a list + + Can also work on list/array of indices: `slice_arrays(x, indices)` + + Arguments: + arrays: Single array or list of arrays. + start: can be an integer index (start index) + or a list/array of indices + stop: integer (stop index); should be None if + `start` was a list. + + Returns: + A slice of the array(s). + + Raises: + ValueError: If the value of start is a list and stop is not None. + """ + if arrays is None: + return [None] + if isinstance(start, list) and stop is not None: + raise ValueError('The stop argument has to be None if the value of start is' + 'a list.') + elif isinstance(arrays, list): + if hasattr(start, '__len__'): + # hdf5 datasets only support list objects as indices + if hasattr(start, 'shape'): + start = start.tolist() + return [None if x is None else x[start] for x in arrays] + else: + return [None if x is None else x[start:stop] for x in arrays] + else: + if hasattr(start, '__len__'): + if hasattr(start, 'shape'): + start = start.tolist() + return arrays[start] + elif hasattr(start, '__getitem__'): + return arrays[start:stop] + else: + return [None] + + +def to_list(x): + """Normalizes a list/tensor into a list. + + If a tensor is passed, we return + a list of size 1 containing the tensor. + + Arguments: + x: target object to be normalized. + + Returns: + A list. + """ + if isinstance(x, list): + return x + return [x] diff --git a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py index a9c8fa68c9a0412befc82b1cc32e11dfcd49cebb..4c8009dfd80e1aec457fa03687f2840c7fe4607b 100644 --- a/tensorflow/python/keras/_impl/keras/utils/layer_utils.py +++ b/tensorflow/python/keras/_impl/keras/utils/layer_utils.py @@ -59,6 +59,10 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): if model.__class__.__name__ == 'Sequential': sequential_like = True + elif not model._is_graph_network: + # We treat subclassed models as a simple sequence of layers, for logging + # purposes. + sequential_like = True else: sequential_like = True nodes_by_depth = model._nodes_by_depth.values() @@ -118,17 +122,24 @@ def print_summary(model, line_length=None, positions=None, print_fn=None): print_fn('=' * line_length) def print_layer_summary(layer): + """Prints a summary for a single layer. + + Arguments: + layer: target layer. + """ try: output_shape = layer.output_shape except AttributeError: output_shape = 'multiple' + except RuntimeError: # output_shape unknown in Eager mode. + output_shape = '?' name = layer.name cls_name = layer.__class__.__name__ fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()] print_row(fields, positions) def print_layer_summary_with_connections(layer): - """Prints a summary for a single layer. + """Prints a summary for a single layer (including topological connections). Arguments: layer: target layer. diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index d4ceb2e489c8a20d26eaf9d89b12992d2b8673d7..c9aa4a252dadf9b8d6b1a4fea50cce3cec57265a 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2892,6 +2892,40 @@ tf_py_test( ], ) +tf_py_test( + name = "accumulate_n_test", + size = "small", + srcs = ["accumulate_n_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + +tf_py_test( + name = "accumulate_n_eager_test", + size = "small", + srcs = ["accumulate_n_eager_test.py"], + additional_deps = [ + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py similarity index 72% rename from tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py rename to tensorflow/python/kernel_tests/accumulate_n_eager_test.py index 35974b9e21d2d7423777a95a99f51c9cb4b453b2..dc11b7deceb9040584aca1f629f4d003aef39428 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py +++ b/tensorflow/python/kernel_tests/accumulate_n_eager_test.py @@ -12,48 +12,41 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into -`ops.math_ops`. - -These test cases spefically exercise the `eager` APIs. They need to be in a -separate file from the remaining tests because eager mode is currently something -you can turn on but can't turn off for the lifetime of the current process.""" +"""Tests for new version of accumulate_n op.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 - from tensorflow.python.eager import backprop from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test - class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): - """Tests of the new, differentiable version of accumulate_n""" + """Tests of the new, differentiable version of accumulate_n.""" def testMinimalEagerMode(self): forty = constant_op.constant(40) two = constant_op.constant(2) - answer = av2.accumulate_n_v2([forty, two]) + answer = math_ops.accumulate_n([forty, two]) self.assertEqual(42, answer.numpy()) - def testFloat(self): np.random.seed(12345) x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] tf_x = ops.convert_n_to_tensor(x) with self.test_session(use_gpu=True): - self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).numpy()) - self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).numpy()) + self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).numpy()) + self.assertAllClose(x[0] * 5, + math_ops.accumulate_n([tf_x[0]] * 5).numpy()) def testGrad(self): np.random.seed(42) @@ -65,16 +58,14 @@ class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): ] def fn(first, second, third): - return av2.accumulate_n_v2([first, second, third]) + return math_ops.accumulate_n([first, second, third]) grad_fn = backprop.gradients_function(fn) grad = grad_fn(input_vars[0], input_vars[1], input_vars[2]) - self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 [elem.numpy() for elem in grad]) - if __name__ == "__main__": ops.enable_eager_execution() test.main() - diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/python/kernel_tests/accumulate_n_test.py similarity index 79% rename from tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py rename to tensorflow/python/kernel_tests/accumulate_n_test.py index 45962098e93acfac414396ddbeaa847701ff2b4b..0a6d4aea370eb788de0c65b4758a3210a7d2944d 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py +++ b/tensorflow/python/kernel_tests/accumulate_n_test.py @@ -12,42 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for new version of accumulate_n op that will eventually go into -`ops.math_ops`.""" +"""Tests for new version of accumulate_n op.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import numpy as np -from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 - from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest class AccumulateNV2Test(test_util.TensorFlowTestCase): - """Tests of the new, differentiable version of accumulate_n""" + """Tests of the new, differentiable version of accumulate_n.""" def testFloat(self): np.random.seed(12345) x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] tf_x = ops.convert_n_to_tensor(x) with self.test_session(use_gpu=True): - self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).eval()) - self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).eval()) + self.assertAllClose(sum(x), math_ops.accumulate_n(tf_x).eval()) + self.assertAllClose(x[0] * 5, + math_ops.accumulate_n([tf_x[0]] * 5).eval()) def testInt(self): np.random.seed(54321) x = [np.random.randint(-128, 128, (5, 4, 3, 2, 1)) for _ in range(6)] tf_x = ops.convert_n_to_tensor(x) with self.test_session(use_gpu=True): - self.assertAllEqual(sum(x), av2.accumulate_n_v2(tf_x).eval()) - self.assertAllEqual(x[0] * 6, av2.accumulate_n_v2([tf_x[0]] * 6).eval()) + self.assertAllEqual(sum(x), math_ops.accumulate_n(tf_x).eval()) + self.assertAllEqual(x[0] * 6, + math_ops.accumulate_n([tf_x[0]] * 6).eval()) def testGrad(self): np.random.seed(42) @@ -55,9 +55,9 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.test_session(use_gpu=True) as sess: input_vars = [ variables.Variable(10.0 * np.random.random()) - for i in range(0, num_inputs) + for _ in range(0, num_inputs) ] - accum_n = av2.accumulate_n_v2(input_vars) + accum_n = math_ops.accumulate_n(input_vars) sess.run(variables.global_variables_initializer()) accum_n_grad = gradients.gradients(accum_n, input_vars) self.assertAllEqual( @@ -77,7 +77,7 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): ops.convert_to_tensor(x, dtype=dtypes_lib.float32) for x in random_arrays ] - tf_val = av2.accumulate_n_v2(random_tensors) + tf_val = math_ops.accumulate_n(random_tensors) np_val = random_arrays[0] for random_array in random_arrays[1:]: np_val += random_array @@ -86,7 +86,7 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): def testZeroArgs(self): with self.test_session(): with self.assertRaises(ValueError): - tf_val = av2.accumulate_n_v2([]) + tf_val = math_ops.accumulate_n([]) tf_val.eval() def testWrongShape(self): @@ -94,28 +94,28 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): a = variables.Variable(0.2) b = variables.Variable(0.1) - tf_val = av2.accumulate_n_v2([a, b], shape=[2, 2]) # Should be shape=[] + math_ops.accumulate_n([a, b], shape=[2, 2]) # Should be shape=[] def testIncompatibleShapes(self): with self.test_session(): with self.assertRaises(ValueError): a = variables.Variable(np.array([0.1, 0.2])) b = variables.Variable(np.array([[0.3], [0.4]])) - tf_val = av2.accumulate_n_v2([a, b]) + math_ops.accumulate_n([a, b]) def testWrongType(self): with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) b = variables.Variable(0.1, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a, b], tensor_dtype=np.int32) + math_ops.accumulate_n([a, b], tensor_dtype=np.int32) def testWrongTypeOneInput(self): # Scenario that used to trigger a bug, even when testWrongType() worked with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) - tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) + math_ops.accumulate_n([a], tensor_dtype=np.int32) if __name__ == "__main__": diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py index 1e2ea829884f2f97ab2203b54228365d85a9dea0..365cf72108de5a1e5e1eb47891a6ad64151add22 100644 --- a/tensorflow/python/kernel_tests/array_ops_test.py +++ b/tensorflow/python/kernel_tests/array_ops_test.py @@ -498,7 +498,7 @@ class StridedSliceTest(test_util.TensorFlowTestCase): def test_basic_slice(self): for tensor_type in STRIDED_SLICE_TYPES: - with self.test_session(use_gpu=True): + with self.test_session(use_gpu=not tensor_type.is_integer): checker = StridedSliceChecker( self, StridedSliceChecker.REF_TENSOR, tensor_type=tensor_type) _ = checker[:, :, :] @@ -884,7 +884,8 @@ class StridedSliceAssignChecker(object): if self.tensor_type.is_complex: value -= 1j * value - with self.test.test_session(use_gpu=True) as sess: + with self.test.test_session( + use_gpu=not self.tensor_type.is_integer) as sess: if self._use_resource: var = resource_variable_ops.ResourceVariable(self.x) else: @@ -974,9 +975,7 @@ class SliceAssignTest(test_util.TensorFlowTestCase): errors.InvalidArgumentError, "l-value dtype int32 does not match r-value dtype int64"): sess.run(v[:].assign(too_large_val)) - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "l-value dtype int32 does not match r-value dtype int8"): + with self.assertRaises(errors.InvalidArgumentError): sess.run(v[:].assign(too_small_val)) diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 15ff0ec09b65a8ba242473fb7b25ee00424e0926..58f38650eb526e98edf35b2425e0e9e1296ab353 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1840,6 +1840,23 @@ class ControlFlowTest(test.TestCase): [tensor_shape.unknown_shape()]) self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0})) + def testCondGradInNestedWhiles(self): + def outer_body(i, x): + _, x = control_flow_ops.while_loop( + lambda j, x: j < 3, inner_body, [0, 0.0]) + return i + 1, x + + def inner_body(j, x): + y = control_flow_ops.cond(math_ops.less(x, 1), lambda: 2 * x, lambda: x) + return j + 1, gradients_impl.gradients(y, x)[0] + + i, x = control_flow_ops.while_loop(lambda i, x: i < 3, outer_body, [0, 0.0]) + + with self.test_session() as sess: + i_val, x_val = sess.run([i, x]) + self.assertEqual(i_val, 3) + self.assertAllClose(x_val, 1.0) + def testWhile_NestedInput(self): with self.test_session() as sess: named = collections.namedtuple("named", ("a", "b")) diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py index edfb20d6a2b80cec930ddf696e8f0f69623a4de7..f4fe01f868da25660171c614bbf84390aead3ade 100644 --- a/tensorflow/python/kernel_tests/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_test.py @@ -302,25 +302,20 @@ class Conv2DTest(test.TestCase): padding, dilations): expected_results = [] computed_results = [] - default_dilations = (dilations[0] == 1 and dilations[1] == 1) for data_format, use_gpu in GetTestConfigs(): - # If any dilation rate is larger than 1, only do test on the GPU - # because we currently do not have a CPU implementation for arbitrary - # dilation rates. - if default_dilations or use_gpu: - expected, computed = self._ComputeReferenceDilatedConv( - tensor_in_sizes, filter_in_sizes, strides, dilations, padding, - data_format, use_gpu) - expected_results.append(expected) - computed_results.append(computed) - tolerance = 1e-2 if use_gpu else 1e-5 - expected_values = self.evaluate(expected_results) - computed_values = self.evaluate(computed_results) - for e_value, c_value in zip(expected_values, computed_values): - print("expected = ", e_value) - print("actual = ", c_value) - self.assertAllClose( - e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4) + expected, computed = self._ComputeReferenceDilatedConv( + tensor_in_sizes, filter_in_sizes, strides, dilations, padding, + data_format, use_gpu) + expected_results.append(expected) + computed_results.append(computed) + tolerance = 1e-2 if use_gpu else 1e-5 + expected_values = self.evaluate(expected_results) + computed_values = self.evaluate(computed_results) + for e_value, c_value in zip(expected_values, computed_values): + print("expected = ", e_value) + print("actual = ", c_value) + self.assertAllClose( + e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-4) def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding, expected): @@ -365,13 +360,12 @@ class Conv2DTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Filter2x1Dilation(self): - if test.is_gpu_available(cuda_only=True): - self._VerifyDilatedConvValues( - tensor_in_sizes=[1, 4, 4, 1], - filter_in_sizes=[2, 2, 1, 1], - strides=[1, 1], - dilations=[2, 1], - padding="VALID") + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 4, 4, 1], + filter_in_sizes=[2, 2, 1, 1], + strides=[1, 1], + dilations=[2, 1], + padding="VALID") @test_util.run_in_graph_and_eager_modes() def testConv2DEmpty(self): @@ -385,13 +379,12 @@ class Conv2DTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testConv2DEmptyDilation(self): - if test.is_gpu_available(cuda_only=True): - self._VerifyDilatedConvValues( - tensor_in_sizes=[0, 2, 3, 3], - filter_in_sizes=[1, 1, 3, 3], - strides=[1, 1], - dilations=[2, 1], - padding="VALID") + self._VerifyDilatedConvValues( + tensor_in_sizes=[0, 2, 3, 3], + filter_in_sizes=[1, 1, 3, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID") @test_util.run_in_graph_and_eager_modes() def testConv2D2x2Filter(self): @@ -406,13 +399,12 @@ class Conv2DTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testConv2D2x2FilterDilation(self): - if test.is_gpu_available(cuda_only=True): - self._VerifyDilatedConvValues( - tensor_in_sizes=[1, 2, 3, 3], - filter_in_sizes=[2, 2, 3, 3], - strides=[1, 1], - dilations=[1, 2], - padding="VALID") + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[2, 2, 3, 3], + strides=[1, 1], + dilations=[1, 2], + padding="VALID") @test_util.run_in_graph_and_eager_modes() def testConv2D1x2Filter(self): @@ -430,13 +422,12 @@ class Conv2DTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testConv2D1x2FilterDilation(self): - if test.is_gpu_available(cuda_only=True): - self._VerifyDilatedConvValues( - tensor_in_sizes=[1, 2, 3, 3], - filter_in_sizes=[1, 2, 3, 3], - strides=[1, 1], - dilations=[2, 1], - padding="VALID") + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 2, 3, 3], + filter_in_sizes=[1, 2, 3, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID") @test_util.run_in_graph_and_eager_modes() def testConv2D2x2FilterStride2(self): @@ -512,13 +503,12 @@ class Conv2DTest(test.TestCase): @test_util.run_in_graph_and_eager_modes() def testConv2DKernelSizeMatchesInputSizeDilation(self): - if test.is_gpu_available(cuda_only=True): - self._VerifyDilatedConvValues( - tensor_in_sizes=[1, 3, 3, 1], - filter_in_sizes=[2, 2, 1, 2], - strides=[1, 1], - dilations=[2, 2], - padding="VALID") + self._VerifyDilatedConvValues( + tensor_in_sizes=[1, 3, 3, 1], + filter_in_sizes=[2, 2, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID") # TODO(yzhwang): this currently fails. # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1], @@ -1523,36 +1513,6 @@ class Conv2DTest(test.TestCase): strides=[1, 1, 1, 1], padding="VALID")) - def testCPUConv2DNCHWUnimplemented(self): - with self.test_session(use_gpu=False): - with self.assertRaisesRegexp(errors_impl.UnimplementedError, - "NHWC tensor format for now"): - conv = self._SetupValuesForDevice( - tensor_in_sizes=[1, 4, 4, 1], - filter_in_sizes=[2, 2, 1, 1], - dilations=[1, 1], - strides=[1, 1], - padding="VALID", - data_format="NCHW", - dtype=dtypes.float32, - use_gpu=False) - self.evaluate(conv) - - def testCPUConv2DDilatedUnimplemented(self): - with self.test_session(use_gpu=False): - with self.assertRaisesRegexp(errors_impl.UnimplementedError, - "dilated rate of 1 for now"): - conv = self._SetupValuesForDevice( - tensor_in_sizes=[1, 4, 4, 1], - filter_in_sizes=[2, 2, 1, 1], - dilations=[2, 1], - strides=[1, 1], - padding="VALID", - data_format="NHWC", - dtype=dtypes.float32, - use_gpu=False) - self.evaluate(conv) - class DepthwiseConv2DTest(test.TestCase): @@ -1887,7 +1847,7 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding, def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding): def Test(self): - if test.is_gpu_available(cuda_only=True) and stride == 1: + if stride == 1: tf_logging.info("Testing InceptionFwd with dilations %s", (input_size, filter_size, stride, padding)) self._VerifyDilatedConvValues( diff --git a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py index fedbf9e696923a34968e7a907e4099c520d1447b..5e8937ad2c36afb2b1ddb58ffb238a45e09e4b30 100644 --- a/tensorflow/python/kernel_tests/dynamic_partition_op_test.py +++ b/tensorflow/python/kernel_tests/dynamic_partition_op_test.py @@ -326,6 +326,18 @@ class DynamicPartitionTest(test.TestCase): with self.assertRaises(ValueError): data_flow_ops.dynamic_partition(data, indices, num_partitions=4) + # see https://github.com/tensorflow/tensorflow/issues/17106 + def testCUBBug(self): + x = constant_op.constant(np.random.randn(3072)) + inds = [0]*189 + [1]*184 + [2]*184 + [3]*191 + [4]*192 + [5]*195 + [6]*195 + inds += [7]*195 + [8]*188 + [9]*195 + [10]*188 + [11]*202 + [12]*194 + inds += [13]*194 + [14]*194 + [15]*192 + self.assertEqual(len(inds), x.shape[0]) + partitioned = data_flow_ops.dynamic_partition(x, inds, 16) + with self.test_session() as sess: + res = sess.run(partitioned) + self.assertEqual(res[-1].shape[0], 192) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py index 343d158498833dd92361bc41d433e28296fc4c9a..8cb9f9e6213cda8daae7b629fc31d4721fd48fa7 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py @@ -129,7 +129,7 @@ class LinearOperatorDiagTest( with self.test_session() as sess: x = random_ops.random_normal(shape=(2, 2, 3, 4)) - # This LinearOperatorDiag will be brodacast to (2, 2, 3, 3) during solve + # This LinearOperatorDiag will be broadcast to (2, 2, 3, 3) during solve # and matmul with 'x' as the argument. diag = random_ops.random_uniform(shape=(2, 1, 3)) operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True) diff --git a/tensorflow/python/kernel_tests/losses_test.py b/tensorflow/python/kernel_tests/losses_test.py index 197dbf44afaea2cfaf5a1ffebb6ac0a6be09d165..1123c20a165ba93bd380fa471a8be91f7005d7bb 100644 --- a/tensorflow/python/kernel_tests/losses_test.py +++ b/tensorflow/python/kernel_tests/losses_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl @@ -32,11 +33,25 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses +from tensorflow.python.ops.losses import losses_impl from tensorflow.python.ops.losses import util from tensorflow.python.platform import test from tensorflow.python.training import momentum as momentum_lib +safe_div = losses_impl._safe_div # pylint: disable=protected-access + + +class SafeDivTest(test.TestCase): + + def testEager(self): + with context.eager_mode(): + self.assertAllEqual(safe_div(constant_op.constant(1.0), + constant_op.constant(0.0)), 0.0) + self.assertAllEqual(safe_div(constant_op.constant(1.0), + 0.0), 0.0) + + class AbsoluteDifferenceLossTest(test.TestCase): def setUp(self): diff --git a/tensorflow/python/kernel_tests/metrics_test.py b/tensorflow/python/kernel_tests/metrics_test.py index e0e752147cdf8690d22fa782aca2561b2935fa8e..59e7afa2dcb1e02ed9c66e5cf75753f96552b4e0 100644 --- a/tensorflow/python/kernel_tests/metrics_test.py +++ b/tensorflow/python/kernel_tests/metrics_test.py @@ -417,7 +417,7 @@ class MeanTensorTest(test.TestCase): self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5) - def testWeighted1d(self): + def testBinaryWeighted1d(self): with self.test_session() as sess: # Create the queue that populates the values. values_queue = data_flow_ops.FIFOQueue( @@ -444,6 +444,33 @@ class MeanTensorTest(test.TestCase): sess.run(update_op) self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5) + def testWeighted1d(self): + with self.test_session() as sess: + # Create the queue that populates the values. + values_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, values_queue, [0, 1]) + _enqueue_vector(sess, values_queue, [-4.2, 9.1]) + _enqueue_vector(sess, values_queue, [6.5, 0]) + _enqueue_vector(sess, values_queue, [-3.2, 4.0]) + values = values_queue.dequeue() + + # Create the queue that populates the weights. + weights_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) + _enqueue_vector(sess, weights_queue, [[0.0025]]) + _enqueue_vector(sess, weights_queue, [[0.005]]) + _enqueue_vector(sess, weights_queue, [[0.01]]) + _enqueue_vector(sess, weights_queue, [[0.0075]]) + weights = weights_queue.dequeue() + + mean, update_op = metrics.mean_tensor(values, weights) + + sess.run(variables.local_variables_initializer()) + for _ in range(4): + sess.run(update_op) + self.assertAllClose([[0.8, 3.52]], sess.run(mean), 5) + def testWeighted2d_1(self): with self.test_session() as sess: # Create the queue that populates the values. @@ -1105,9 +1132,9 @@ class AUCTest(test.TestCase): auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.79166, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.54166, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.79166, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.54166, auc.eval(), delta=1e-3) def testAnotherAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1119,9 +1146,9 @@ class AUCTest(test.TestCase): auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.610317, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3) - self.assertAlmostEqual(0.610317, auc.eval(), delta=1e-3) + self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3) def testThirdAUCPRSpecialCase(self): with self.test_session() as sess: @@ -1133,9 +1160,26 @@ class AUCTest(test.TestCase): auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(0.90277, sess.run(update_op), delta=1e-3) + self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3) + + self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3) - self.assertAlmostEqual(0.90277, auc.eval(), delta=1e-3) + def testFourthAUCPRSpecialCase(self): + # Create the labels and data. + labels = np.array([ + 0, 0, 0, 0, 0, 0, 0, 1, 0, 1]) + predictions = np.array([ + 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35]) + + with self.test_session() as sess: + auc, _ = metrics.auc( + labels, predictions, curve='PR', num_thresholds=11) + + sess.run(variables.local_variables_initializer()) + # Since this is only approximate, we can't expect a 6 digits match. + # Although with higher number of samples/thresholds we should see the + # accuracy improving + self.assertAlmostEqual(0.0, auc.eval(), delta=0.001) def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -1161,16 +1205,16 @@ class AUCTest(test.TestCase): self.assertAlmostEqual(1, auc.eval(), 6) - def testRecallOneAndPrecisionOneGivesOnePRAUC(self): + def testRecallOneAndPrecisionOne(self): with self.test_session() as sess: predictions = array_ops.ones([4], dtype=dtypes_lib.float32) labels = array_ops.ones([4]) auc, update_op = metrics.auc(labels, predictions, curve='PR') sess.run(variables.local_variables_initializer()) - self.assertAlmostEqual(1, sess.run(update_op), 6) + self.assertAlmostEqual(0.5, sess.run(update_op), 6) - self.assertAlmostEqual(1, auc.eval(), 6) + self.assertAlmostEqual(0.5, auc.eval(), 6) def np_auc(self, predictions, labels, weights): """Computes the AUC explicitly using Numpy. diff --git a/tensorflow/python/kernel_tests/reduction_ops_test.py b/tensorflow/python/kernel_tests/reduction_ops_test.py index 531478162971575739bbe37abfc57ca427ae22ae..d306d1b8d64f292dc299deee2e3c36981b933d1e 100644 --- a/tensorflow/python/kernel_tests/reduction_ops_test.py +++ b/tensorflow/python/kernel_tests/reduction_ops_test.py @@ -887,11 +887,7 @@ class AnyReductionTest(test.TestCase): class CountNonzeroReductionTest(test.TestCase): - def _compare(self, - x, - reduction_axes, - keepdims, - use_gpu=False, + def _compare(self, x, reduction_axes, keepdims, use_gpu=False, feed_dict=None): np_ans = (x != 0).astype(np.int32) if reduction_axes is None: diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index dc6e73bd5b7930d9292a4654734f55c6b29d4389..8503f3e0310125bb714942b32bbbf46596f9bddb 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -64,6 +64,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): 0, dtype=dtypes.int32)).run() + def testGPUInt64(self): + if not context.context().num_gpus(): + return + with context.eager_mode(), context.device("gpu:0"): + v = resource_variable_ops.ResourceVariable(1, dtype=dtypes.int64) + self.assertAllEqual(1, v.numpy()) + def testEagerNameNotIdentity(self): with context.eager_mode(): v0 = resource_variable_ops.ResourceVariable(1.0, name="a") @@ -162,14 +169,15 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testScatterAdd(self): - handle = resource_variable_ops.var_handle_op( - dtype=dtypes.int32, shape=[1, 1]) - self.evaluate(resource_variable_ops.assign_variable_op( - handle, constant_op.constant([[1]], dtype=dtypes.int32))) - self.evaluate(resource_variable_ops.resource_scatter_add( - handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) - read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) - self.assertEqual(self.evaluate(read), [[3]]) + with ops.device("cpu:0"): + handle = resource_variable_ops.var_handle_op( + dtype=dtypes.int32, shape=[1, 1]) + self.evaluate(resource_variable_ops.assign_variable_op( + handle, constant_op.constant([[1]], dtype=dtypes.int32))) + self.evaluate(resource_variable_ops.resource_scatter_add( + handle, [0], constant_op.constant([[2]], dtype=dtypes.int32))) + read = resource_variable_ops.read_variable_op(handle, dtype=dtypes.int32) + self.assertEqual(self.evaluate(read), [[3]]) def testScatterUpdateString(self): handle = resource_variable_ops.var_handle_op( diff --git a/tensorflow/python/kernel_tests/stack_op_test.py b/tensorflow/python/kernel_tests/stack_op_test.py index 347baf81148e9b747a9be4849912d154b220a084..2f27d1839b2218d0cc33d7278116186548ad3420 100644 --- a/tensorflow/python/kernel_tests/stack_op_test.py +++ b/tensorflow/python/kernel_tests/stack_op_test.py @@ -50,7 +50,7 @@ class StackOpTest(test.TestCase): # Convert [data[0], data[1], ...] separately to tensorflow # TODO(irving): Remove list() once we handle maps correctly xs = list(map(constant_op.constant, data)) - # Pack back into a single tensorflow tensor + # Stack back into a single tensorflow tensor c = array_ops.stack(xs) self.assertAllEqual(c.eval(), data) @@ -78,7 +78,7 @@ class StackOpTest(test.TestCase): for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2): for dtype in [np.bool, np.float32, np.int32, np.int64]: data = np.random.randn(*shape).astype(dtype) - # Pack back into a single tensorflow tensor directly using np array + # Stack back into a single tensorflow tensor directly using np array c = array_ops.stack(data) # This is implemented via a Const: self.assertEqual(c.op.type, "Const") @@ -223,7 +223,7 @@ class StackOpTest(test.TestCase): array_ops.stack(t, axis=-3) -class AutomaticPackingTest(test.TestCase): +class AutomaticStackingTest(test.TestCase): def testSimple(self): with self.test_session(use_gpu=True): diff --git a/tensorflow/python/kernel_tests/unique_op_test.py b/tensorflow/python/kernel_tests/unique_op_test.py index 6366d2e181c8cfabba8a78b664c25c85debc67ef..4498fd9fe9986c134b92aed192a6de6f06109bd9 100644 --- a/tensorflow/python/kernel_tests/unique_op_test.py +++ b/tensorflow/python/kernel_tests/unique_op_test.py @@ -133,6 +133,39 @@ class UniqueWithCountsTest(test.TestCase): v = [1 if x[i] == value.decode('ascii') else 0 for i in range(7000)] self.assertEqual(count, sum(v)) + def testInt32Axis(self): + for dtype in [np.int32, np.int64]: + x = np.array([[1, 0, 0], [1, 0, 0], [2, 0, 0]]) + with self.test_session() as sess: + y0, idx0, count0 = gen_array_ops._unique_with_counts_v2( + x, axis=np.array([0], dtype)) + tf_y0, tf_idx0, tf_count0 = sess.run([y0, idx0, count0]) + y1, idx1, count1 = gen_array_ops._unique_with_counts_v2( + x, axis=np.array([1], dtype)) + tf_y1, tf_idx1, tf_count1 = sess.run([y1, idx1, count1]) + self.assertAllEqual(tf_y0, np.array([[1, 0, 0], [2, 0, 0]])) + self.assertAllEqual(tf_idx0, np.array([0, 0, 1])) + self.assertAllEqual(tf_count0, np.array([2, 1])) + self.assertAllEqual(tf_y1, np.array([[1, 0], [1, 0], [2, 0]])) + self.assertAllEqual(tf_idx1, np.array([0, 1, 1])) + self.assertAllEqual(tf_count1, np.array([1, 2])) + + def testInt32V2(self): + # This test is only temporary, once V2 is used + # by default, the axis will be wrapped to allow `axis=None`. + x = np.random.randint(2, high=10, size=7000) + with self.test_session() as sess: + y, idx, count = gen_array_ops._unique_with_counts_v2( + x, axis=np.array([], np.int32)) + tf_y, tf_idx, tf_count = sess.run([y, idx, count]) + + self.assertEqual(len(x), len(tf_idx)) + self.assertEqual(len(tf_y), len(np.unique(x))) + for i in range(len(x)): + self.assertEqual(x[i], tf_y[tf_idx[i]]) + for value, count in zip(tf_y, tf_count): + self.assertEqual(count, np.sum(x == value)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/python/kernel_tests/unstack_op_test.py b/tensorflow/python/kernel_tests/unstack_op_test.py index 84818755766a435c873f30e96dc0080af4f78b84..1ee6e0866a6b1c7a9b641a95403d45213f5dc0b4 100644 --- a/tensorflow/python/kernel_tests/unstack_op_test.py +++ b/tensorflow/python/kernel_tests/unstack_op_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Functional tests for Unpack Op.""" +"""Functional tests for Unstack Op.""" from __future__ import absolute_import from __future__ import division @@ -49,7 +49,7 @@ class UnstackOpTest(test.TestCase): data = np.random.randn(*shape).astype(dtype) # Convert data to a single tensorflow tensor x = constant_op.constant(data) - # Unpack into a list of tensors + # Unstack into a list of tensors cs = array_ops.unstack(x, num=shape[0]) self.assertEqual(type(cs), list) self.assertEqual(len(cs), shape[0]) @@ -66,7 +66,7 @@ class UnstackOpTest(test.TestCase): data = np.random.randn(*shape).astype(dtype) # Convert data to a single tensorflow tensor x = constant_op.constant(data) - # Unpack into a list of tensors + # Unstack into a list of tensors cs = array_ops.unstack(x, num=shape[0]) self.assertEqual(type(cs), list) self.assertEqual(len(cs), shape[0]) diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index 0d78ef25f24b1796f5eb865bcf794cb46a9a031e..8314c4aa87a5b54effc44c371703267517ffa07d 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -263,6 +263,8 @@ class Layer(object): return # Updates already applied when in eager mode. updates = _to_list(updates) + updates = [x if isinstance(x, ops.Operation) + else ops.convert_to_tensor(x) for x in updates] self._updates += updates if inputs is None: for u in updates: @@ -730,12 +732,10 @@ class Layer(object): activity_regularization = self._activity_regularizer(output) self.add_loss(activity_regularization, inputs=inputs) - if not in_deferred_mode: - # TODO(fchollet): consider how masking will work with deferred mode. - # Handle mask computation and propagation to the next layer. + # TODO(fchollet): consider enabling masking for Eager mode. if hasattr(self, 'compute_mask'): output_mask = self.compute_mask(inputs, previous_mask) - if isinstance(outputs, list): + if isinstance(outputs, (list, tuple)): if output_mask is None: output_mask = [None for _ in range(len(outputs))] for x, m in zip(outputs, output_mask): diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py index 689046fe78832ebeb2a44a59797dc57396e9ce16..bb10fe5e8bfd26e4877fb6aef73980a30f62bb5d 100644 --- a/tensorflow/python/layers/convolutional.py +++ b/tensorflow/python/layers/convolutional.py @@ -1096,10 +1096,10 @@ class SeparableConv1D(_SeparableConv): def call(self, inputs): if self.data_format == 'channels_last': - strides = (1, 1) + self.strides + (1,) + strides = (1,) + self.strides * 2 + (1,) spatial_start_dim = 1 else: - strides = (1, 1, 1) + self.strides + strides = (1, 1) + self.strides * 2 spatial_start_dim = 2 # Explicitly broadcast inputs and kernels to 4D. diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py index ec4fca78f046aff0ec6f6e65d5397d2649b329f1..6970bf9234f5a31ee8093069ac1c933bcdb6f103 100644 --- a/tensorflow/python/layers/core.py +++ b/tensorflow/python/layers/core.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import standard_ops from tensorflow.python.util.tf_export import tf_export @@ -291,13 +292,7 @@ class Dropout(base.Layer): # shapes with dynamically sized inputs. if self.noise_shape is None: return self.noise_shape - - symbolic_shape = array_ops.shape(inputs) - noise_shape = [ - symbolic_shape[axis] if shape is None else shape - for axis, shape in enumerate(self.noise_shape) - ] - return noise_shape + return nn_ops._get_noise_shape(inputs, self.noise_shape) def call(self, inputs, training=False): diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py index 1555846efde812b9e31f48315decaf1f86aa4f70..13a8e8e39caaf9c74d1c7d0ea4d6856f725256fd 100644 --- a/tensorflow/python/layers/layers.py +++ b/tensorflow/python/layers/layers.py @@ -68,7 +68,6 @@ from tensorflow.python.util.all_util import remove_undocumented # Base objects. from tensorflow.python.layers.base import Layer from tensorflow.python.layers.base import InputSpec -from tensorflow.python.layers.network import Input # Core layers. from tensorflow.python.layers.core import Dense diff --git a/tensorflow/python/layers/maxout.py b/tensorflow/python/layers/maxout.py deleted file mode 100644 index 765a1c4fdafdfdc5d3ea6629d4d9290d8b658902..0000000000000000000000000000000000000000 --- a/tensorflow/python/layers/maxout.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -# pylint: disable=unused-import,g-bad-import-order -"""Contains the maxout layer -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.eager import context -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import gen_array_ops - -from tensorflow.python.layers import base - - -def maxout(inputs, num_units, axis=-1, name=None): - """Adds a maxout op from https://arxiv.org/abs/1302.4389 - - "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron - Courville, - Yoshua Bengio - - Usually the operation is performed in the filter/channel dimension. This can - also be - used after fully-connected layers to reduce number of features. - - Arguments: - inputs: Tensor input - num_units: Specifies how many features will remain after maxout in the `axis` - dimension - (usually channel). This must be multiple of number of `axis`. - axis: The dimension where max pooling will be performed. Default is the - last dimension. - name: Optional scope for name_scope. - - Returns: - A `Tensor` representing the results of the pooling operation. - - Raises: - ValueError: if num_units is not multiple of number of features. - """ - return MaxOut(num_units=num_units, axis=axis, name=name)(inputs) - - -class MaxOut(base.Layer): - """Adds a maxout op from https://arxiv.org/abs/1302.4389 - - "Maxout Networks" Ian J. Goodfellow, David Warde-Farley, Mehdi Mirza, Aaron - Courville, Yoshua - Bengio - - Usually the operation is performed in the filter/channel dimension. This can - also be - used after fully-connected layers to reduce number of features. - - Arguments: - inputs: Tensor input - num_units: Specifies how many features will remain after maxout in the - `axis` dimension - (usually channel). - This must be multiple of number of `axis`. - axis: The dimension where max pooling will be performed. Default is the - last dimension. - name: Optional scope for name_scope. - - Returns: - A `Tensor` representing the results of the pooling operation. - - Raises: - ValueError: if num_units is not multiple of number of features. - """ - - def __init__(self, num_units, axis=-1, name=None, **kwargs): - super(MaxOut, self).__init__(name=name, trainable=False, **kwargs) - self.axis = axis - self.num_units = num_units - - def call(self, inputs): - inputs = ops.convert_to_tensor(inputs) - shape = inputs.get_shape().as_list() - num_channels = shape[self.axis] - if num_channels % self.num_units: - raise ValueError('number of features({}) is not ' - 'a multiple of num_units({})'.format( - num_channels, self.num_units)) - shape[self.axis] = -1 - shape += [num_channels // self.num_units] - - # Dealing with batches with arbitrary sizes - for i in range(len(shape)): - if shape[i] is None: - shape[i] = gen_array_ops.shape(inputs)[i] - outputs = math_ops.reduce_max( - gen_array_ops.reshape(inputs, shape), -1, keepdims=False) - - return outputs diff --git a/tensorflow/python/layers/maxout_test.py b/tensorflow/python/layers/maxout_test.py deleted file mode 100644 index 26acac57c41da759f288f255c0cd523f9c6b1dbd..0000000000000000000000000000000000000000 --- a/tensorflow/python/layers/maxout_test.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= - -# pylint: disable=unused-import,g-bad-import-order - - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.layers import maxout -from tensorflow.python.layers import convolutional as conv_layers -from tensorflow.python.layers import core as core_layers - -from tensorflow.python.ops import random_ops -from tensorflow.python.platform import test -import numpy as np - -""" -Contains the maxout layer tests -""" - - -class MaxOutTest(test.TestCase): - def test_simple(self): - inputs = random_ops.random_uniform((64, 10, 36), seed=1) - graph = maxout.maxout(inputs, num_units=3) - self.assertEqual(graph.get_shape().as_list(), [64, 10, 3]) - - def test_fully_connected(self): - inputs = random_ops.random_uniform((64, 50), seed=1) - graph = core_layers.dense(inputs, 50) - graph = maxout.maxout(graph, num_units=10) - self.assertEqual(graph.get_shape().as_list(), [64, 10]) - - def test_nchw(self): - inputs = random_ops.random_uniform((10, 100, 100, 3), seed=1) - graph = conv_layers.conv2d(inputs, 10, 3, padding="SAME") - graph = maxout.maxout(graph, num_units=1) - self.assertEqual(graph.get_shape().as_list(), [10, 100, 100, 1]) - - def test_invalid_shape(self): - inputs = random_ops.random_uniform((10, 100, 100, 3), seed=1) - graph = conv_layers.conv2d(inputs, 3, 10, strides=(1, 1)) - with self.assertRaisesRegexp(ValueError, 'number of features'): - graph = maxout.maxout(graph, num_units=2) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/layers/network.py b/tensorflow/python/layers/network.py deleted file mode 100644 index 499f53d21bebfe3572ac9148911962eb868812bc..0000000000000000000000000000000000000000 --- a/tensorflow/python/layers/network.py +++ /dev/null @@ -1,998 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================= -"""Contains Network, a composition of layers.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - - -from tensorflow.python.eager import context -from tensorflow.python.estimator import util as estimator_util -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base -from tensorflow.python.layers import utils as layers_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.util import nest -from tensorflow.python.util.tf_export import tf_export - - -class InputLayer(base.Layer): - """Layer to be used as an entry point into a Network (a graph of layers). - - It can either wrap an existing tensor (pass an `input_tensor` argument) - or create its a placeholder tensor (pass arguments `input_shape` - as well as `dtype`). - - It is generally recommend to use the functional layer API via `Input`, - (which creates an `InputLayer`) without directly using `InputLayer`. - - Arguments: - input_shape: Shape tuple (not including the batch axis), or `TensorShape` - instance (not including the batch axis). - batch_size: Optional input batch size (integer or None). - dtype: Datatype of the input. - input_tensor: Optional tensor to use as layer input - instead of creating a placeholder. - sparse: Boolean, whether the placeholder created - is meant to be sparse. - name: Name of the layer (string). - - Raises: - RuntimeError: If created in Eager mode. - """ - - def __init__(self, - input_shape=None, - batch_size=None, - dtype=dtypes.float32, - input_tensor=None, - sparse=False, - name=None): - super(InputLayer, self).__init__(dtype=dtype, name=name) - self.built = True - self.sparse = sparse - self.batch_size = batch_size - - if isinstance(input_shape, tensor_shape.TensorShape): - input_shape = tuple(input_shape.as_list()) - - if input_tensor is None: - if input_shape is not None: - batch_input_shape = (batch_size,) + tuple(input_shape) - else: - batch_input_shape = None - - if context.in_eager_mode(): - # In eager mode, create a temporary placeholder to call the layer on. - input_tensor = base._DeferredTensor( # pylint: disable=protected-access - shape=batch_input_shape, - dtype=dtype, - name=self.name) - else: - # In graph mode, create a graph placeholder to call the layer on. - if sparse: - input_tensor = array_ops.sparse_placeholder( - shape=batch_input_shape, - dtype=dtype, - name=self.name) - else: - input_tensor = array_ops.placeholder( - shape=batch_input_shape, - dtype=dtype, - name=self.name) - - # For compatibility with Keras API. - self.is_placeholder = True - self._batch_input_shape = batch_input_shape - else: - # For compatibility with Keras API. - self.is_placeholder = False - self._batch_input_shape = tuple(input_tensor.get_shape().as_list()) - - # Create an input node to add to self.outbound_node - # and set output_tensors' _keras_history. - input_tensor._keras_history = (self, 0, 0) # pylint: disable=protected-access - base.Node( - self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=[input_tensor], - output_tensors=[input_tensor]) - - -@tf_export('layers.Input') -def Input( # pylint: disable=invalid-name - shape=None, - batch_size=None, - name=None, - dtype=dtypes.float32, - sparse=False, - tensor=None): - """`Input()` is used to instantiate an input tensor for use with a `Network`. - - For instance, if a, b and c are tensors created via `Input`, - it becomes possible to do: - - `network = Network(inputs=[a, b], outputs=c)` - - Example: - - ```python - # This is a logistic regression - x = tf.layers.Input(shape=(32,)) - y = tf.layers.Dense(16, activation='softmax')(x) - network = tf.layers.Network(x, y) - ``` - - Arguments: - shape: A shape tuple (integer), not including the batch size. - For instance, `shape=(32,)` indicates that the expected input - will be batches of 32-dimensional vectors. - batch_size: Optional input batch size (integer or None). - name: An optional name string for the layer. - Should be unique in a model (do not reuse the same name twice). - It will be autogenerated if it isn't provided. - dtype: The data type expected by the input, as a string - (`float32`, `float64`, `int32`...) - sparse: A boolean specifying whether the placeholder - to be created is sparse. - tensor: Optional existing tensor to wrap into the `Input` layer. - If set, the layer will not create a placeholder tensor. - - Returns: - A tensor: either a new placeholder (with history metadata) or - `tensor` (if passed), with added history metadata. - - Raises: - RuntimeError: If called in Eager mode. - """ - input_layer = InputLayer( - input_shape=shape, - batch_size=batch_size, - name=name, - dtype=dtype, - sparse=sparse, - input_tensor=tensor) - # Return tensor including `_keras_history` metadata. - # Note that in this case train_output and test_output are the same pointer. - outputs = input_layer._inbound_nodes[0].output_tensors # pylint: disable=protected-access - if len(outputs) == 1: - return outputs[0] - else: - return outputs - - -class GraphNetwork(base.Layer): - """A GraphNetwork is a directed acyclic graph of layers. - - It is the topological form of a `tf.keras.models.Model`. A `Model` is simply a - `GraphNetwork` with added training/evaluation routines. - - A `GraphNetwork` instance implements the full `Layer` API. In particular, a - `GraphNetwork` can be called on new inputs. - - Example: - - ```python - # This is a logistic regression - x = tf.layers.Input(shape=(32,)) - y = tf.layers.Dense(16, activation='softmax')(x) - network = tf.layers.GraphNetwork(x, y) - - # It is then possible to call the network on compatible inputs: - z = tf.layers.Input(shape=(32,)) - w = network(z) - - # It is possible to retrieve the same properties as a layer: - weights = network.trainable_weights - ``` - - Arguments: - inputs: Input tensor or list of input tensors. - Must come from `tf.layers.Input`. - output: Output tensor or list of output tensors. Must come from - tf.layers Layers or Keras layers. - name: Optional name of the model (string). - - Attributes: - GraphNetwork has the same attributes as Layer. On top of it, it also has: - - layers: a list of the children layers of the network, - a list of layer instances, ordered from "earlier in the graph" - to "later in the graph". - - Methods: - GraphNetwork has the same methods as Layer. On top of it, it also has: - - get_layer: retrieves a child layer by name or index in the graph. - - Raises: - RuntimeError: If created in Eager mode. - """ - - def __init__(self, inputs, outputs, name=None): # pylint: disable=super-init-not-called - if context.in_eager_mode(): - # TODO(fchollet): check that all inputs and outputs are DeferredTensors. - pass - - self._init_set_name(name) - self._activity_regularizer = None - with vs.variable_scope( - None, default_name=self._base_name) as captured_scope: - self._scope = captured_scope - call_fn_args = estimator_util.fn_args(self.call) - self._compute_previous_mask = ('mask' in call_fn_args or - hasattr(self, 'compute_mask')) - self._call_has_scope_arg = 'scope' in call_fn_args - - # This acts just like the `trainable` attribute of any layer instance. - # It does not affect users of the underlying layers, only users of the - # GraphNetwork instance. - self.trainable = True - # A GraphNetwork does not create weights of its own, thus it is already - # built. - self.built = True - # A GraphNetwork does not create weights of its own, thus has no dtype. - self._dtype = None - # The following are implemented as property functions: - # self.trainable_weights - # self.non_trainable_weights - # self.input_spec - - # Private attributes to implement compatibility with Layer. - self._updates = [] - self._losses = [] - self._scope = None - self._reuse = None - self._graph = ops.get_default_graph() - - # GraphNetwork-specific properties. - if isinstance(inputs, (list, tuple)): - self.inputs = list(inputs) # Tensor or list of tensors. - else: - self.inputs = [inputs] - if isinstance(outputs, (list, tuple)): - self.outputs = list(outputs) - else: - self.outputs = [outputs] - # All layers in order of horizontal graph traversal. - # Entries are unique. Includes input and output layers. - self.layers = [] - - # Check for redundancy in inputs. - if len(set(self.inputs)) != len(self.inputs): - raise ValueError('The list of inputs passed to the model ' - 'is redundant. ' - 'All inputs should only appear once.' - ' Found: ' + str(self.inputs)) - - # # List of initial layers (1 to 1 mapping with self.inputs, - # # hence the same layer might appear twice) - # self._input_layers = [] - # self._input_layers_node_indices = [] - # self._input_layers_tensor_indices = [] - # # list of layers (1 to 1 mapping with self.inputs, - # # hence the same layer might appear twice) - # self._output_layers = [] - # self._output_layers_node_indices = [] - # self._output_layers_tensor_indices = [] - - self._input_layers = [] - self._output_layers = [] - self._input_coordinates = [] - self._output_coordinates = [] - - # This is for performance optimization when calling the GraphNetwork on new - # inputs. Every time the GraphNetwork is called on a set on input tensors, - # we compute the output tensors, output masks and output shapes in one pass, - # then cache them here. When any of these outputs is queried later, we - # retrieve it from there instead of recomputing it. - self._output_mask_cache = {} - self._output_tensor_cache = {} - self._output_shape_cache = {} - - # User-provided arguments validation. - for x in self.inputs: - # Check that x has appropriate `_keras_history` metadata. - if not hasattr(x, '_keras_history'): - cls_name = self.__class__.__name__ - raise ValueError('Input tensors to a ' + cls_name + ' ' + - 'must come from `tf.layers.Input`. ' - 'Received: ' + str(x) + - ' (missing previous layer metadata).') - # Check that x is an input tensor. - # pylint: disable=protected-access - layer, node_index, tensor_index = x._keras_history - if len(layer._inbound_nodes) > 1 or ( - layer._inbound_nodes and layer._inbound_nodes[0].inbound_layers): - cls_name = self.__class__.__name__ - logging.warning(cls_name + ' inputs must come from ' - '`tf.layers.Input` (thus holding past layer metadata), ' - 'they cannot be the output of ' - 'a previous non-Input layer. ' - 'Here, a tensor specified as ' - 'input to "' + self.name + '" was not an Input tensor, ' - 'it was generated by layer ' + layer.name + '.\n' - 'Note that input tensors are ' - 'instantiated via `tensor = tf.layers.Input(shape)`.\n' - 'The tensor that caused the issue was: ' + str(x.name)) - # pylint: enable=protected-access - for x in self.outputs: - if not hasattr(x, '_keras_history'): - cls_name = self.__class__.__name__ - raise ValueError('Output tensors to a ' + cls_name + ' must be ' - 'the output of a TensorFlow `Layer` ' - '(thus holding past layer metadata). Found: ' + str(x)) - - # Build self._output_layers: - for x in self.outputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - self._output_layers.append(layer) - self._output_coordinates.append((layer, node_index, tensor_index)) - - # Build self._input_layers: - for x in self.inputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - # It's supposed to be an input layer, so only one node - # and one tensor output. - assert node_index == 0 - assert tensor_index == 0 - self._input_layers.append(layer) - self._input_coordinates.append((layer, node_index, tensor_index)) - - # Network_nodes: set of nodes included in the graph - # (not all nodes included in the layers - # are relevant to the current graph). - network_nodes = set() # ids of all nodes relevant to the GraphNetwork - nodes_depths = {} # dict {node: depth value} - layers_depths = {} # dict {layer: depth value} - layer_indices = {} # dict {layer: index in traversal} - nodes_in_decreasing_depth = [] - - def build_map_of_graph(tensor, - finished_nodes, - nodes_in_progress, - layer, - node_index, - tensor_index): - """Builds a map of the graph of layers. - - This recursively updates the map `layer_indices`, - the list `nodes_in_decreasing_depth` and the set `network_nodes`. - - Arguments: - tensor: Some tensor in a graph. - finished_nodes: Set of nodes whose subgraphs have been traversed - completely. Useful to prevent duplicated work. - nodes_in_progress: Set of nodes that are currently active on the - recursion stack. Useful to detect cycles. - layer: Layer from which `tensor` comes from. If not provided, - will be obtained from `tensor._keras_history`. - node_index: Node index from which `tensor` comes from. - tensor_index: Tensor_index from which `tensor` comes from. - - Raises: - ValueError: if a cycle is detected. - """ - node = layer._inbound_nodes[node_index] # pylint: disable=protected-access - - # Prevent cycles. - if node in nodes_in_progress: - raise ValueError('The tensor ' + str(tensor) + ' at layer "' + - layer.name + '" is part of a cycle.') - - # Don't repeat work for shared subgraphs - if node in finished_nodes: - return - - node_key = _make_node_key(layer.name, node_index) - # Update network_nodes. - network_nodes.add(node_key) - - # Store the traversal order for layer sorting. - if layer not in layer_indices: - layer_indices[layer] = len(layer_indices) - - nodes_in_progress.add(node) - - # Propagate to all previous tensors connected to this node. - for i in range(len(node.inbound_layers)): - x = node.input_tensors[i] - layer = node.inbound_layers[i] - node_index = node.node_indices[i] - tensor_index = node.tensor_indices[i] - build_map_of_graph(x, finished_nodes, nodes_in_progress, layer, - node_index, tensor_index) - - finished_nodes.add(node) - nodes_in_progress.remove(node) - nodes_in_decreasing_depth.append(node) - - finished_nodes = set() - nodes_in_progress = set() - for x in self.outputs: - layer, node_index, tensor_index = x._keras_history # pylint: disable=protected-access - build_map_of_graph(x, finished_nodes, nodes_in_progress, - layer=layer, - node_index=node_index, - tensor_index=tensor_index) - - for node in reversed(nodes_in_decreasing_depth): - # If the depth is not set, the node has no outbound nodes (depth 0). - depth = nodes_depths.setdefault(node, 0) - - # Update the depth of the corresponding layer - previous_depth = layers_depths.get(node.outbound_layer, 0) - # If we've seen this layer before at a higher depth, - # we should use that depth instead of the node depth. - # This is necessary for shared layers that have inputs at different - # depth levels in the graph. - depth = max(depth, previous_depth) - layers_depths[node.outbound_layer] = depth - nodes_depths[node] = depth - - # Update the depth of inbound nodes. - # The "depth" of a node is the max of the depths - # of all layers it is connected to. - for i in range(len(node.inbound_layers)): - inbound_layer = node.inbound_layers[i] - node_index = node.node_indices[i] - inbound_node = inbound_layer._inbound_nodes[node_index] # pylint: disable=protected-access - previous_depth = nodes_depths.get(inbound_node, 0) - nodes_depths[inbound_node] = max(depth + 1, previous_depth) - - # Build a dict {depth: list of nodes with this depth} - nodes_by_depth = {} - for node, depth in nodes_depths.items(): - if depth not in nodes_by_depth: - nodes_by_depth[depth] = [] - nodes_by_depth[depth].append(node) - - # Build a dict {depth: list of layers with this depth} - layers_by_depth = {} - for layer, depth in layers_depths.items(): - if depth not in layers_by_depth: - layers_by_depth[depth] = [] - layers_by_depth[depth].append(layer) - - # Get sorted list of layer depths. - depth_keys = list(layers_by_depth.keys()) - depth_keys.sort(reverse=True) - - # Set self.layers and self._layers_by_depth. - layers = [] - for depth in depth_keys: - layers_for_depth = layers_by_depth[depth] - # GraphNetwork.layers needs to have a deterministic order: - # here we order them by traversal order. - layers_for_depth.sort(key=lambda x: layer_indices[x]) - layers.extend(layers_for_depth) - self.layers = layers - self._layers_by_depth = layers_by_depth - - # Get sorted list of node depths. - depth_keys = list(nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - - # Check that all tensors required are computable. - # computable_tensors: all tensors in the graph - # that can be computed from the inputs provided. - computable_tensors = [] - for x in self.inputs: - computable_tensors.append(x) - - layers_with_complete_input = [] # To provide a better error msg. - for depth in depth_keys: - for node in nodes_by_depth[depth]: - layer = node.outbound_layer - if layer: - for x in node.input_tensors: - if x not in computable_tensors: - raise ValueError('Graph disconnected: ' - 'cannot obtain value for tensor ' + str(x) + - ' at layer "' + layer.name + '". ' - 'The following previous layers ' - 'were accessed without issue: ' + - str(layers_with_complete_input)) - for x in node.output_tensors: - computable_tensors.append(x) - layers_with_complete_input.append(layer.name) - - # Keep track of the network's nodes. - self._network_nodes = network_nodes - self._nodes_by_depth = nodes_by_depth - - # Ensure name unicity, which will be crucial for serialization - # (since serialized nodes refer to layers by their name). - all_names = [layer.name for layer in self.layers] - for name in all_names: - if all_names.count(name) != 1: - raise ValueError('The name "' + name + '" is used ' + - str(all_names.count(name)) + ' times in the model. ' - 'All layer names should be unique.') - - # Layer parameters. - # The new network starts with a single inbound node - # for its inputs, and no outbound nodes. - self._outbound_nodes = [] # Will be appended to by future calls to __call__ - self._inbound_nodes = [ - ] # Will be appended to below, and by future calls to __call__ - # Create the node linking internal inputs to internal outputs. - base.Node( - outbound_layer=self, - inbound_layers=[], - node_indices=[], - tensor_indices=[], - input_tensors=self.inputs, - output_tensors=self.outputs) - - def get_layer(self, name=None, index=None): - """Retrieves a layer based on either its name (unique) or index. - - Indices are based on order of horizontal graph traversal (bottom-up). - - Arguments: - name: String, name of layer. - index: Integer, index of layer. - - Returns: - A layer instance. - - Raises: - ValueError: In case of invalid layer name or index. - """ - # TODO(fchollet): We could build a dictionary based on layer names - # since they are constant, but we have not done that yet. - if index is not None: - if len(self.layers) <= index: - raise ValueError('Was asked to retrieve layer at index ' + str(index) + - ' but model only has ' + str(len(self.layers)) + - ' layers.') - else: - return self.layers[index] - else: - if not name: - raise ValueError('Provide either a layer name or layer index.') - for layer in self.layers: - if layer.name == name: - return layer - raise ValueError('No such layer: ' + name) - - @property - def stateful(self): - return any([(hasattr(layer, 'stateful') and layer.stateful) - for layer in self.layers]) - - @property - def updates(self): - """Retrieve the network's updates. - - Will only include updates that are either - unconditional, or conditional on inputs to this model - (e.g. will not include updates that were created by layers of this model - outside of the model). - - Effectively, `network.updates` behaves like `layer.updates`. - - Concrete example: - - ```python - bn = keras.layers.BatchNormalization() - x1 = keras.layers.Input(shape=(10,)) - _ = bn(x1) # This creates 2 updates. - - x2 = keras.layers.Input(shape=(10,)) - y2 = bn(x2) # This creates 2 more updates. - - # The BN layer has now 4 updates. - self.assertEqual(len(bn.updates), 4) - - # Let's create a model from x2 to y2. - model = keras.models.Model(x2, y2) - - # The model does not list all updates from its underlying layers, - # but only the updates that are relevant to it. Updates created by layers - # outside of the model are discarded. - self.assertEqual(len(model.updates), 2) - - # If you keep calling the model, you append to its updates, just like - # what happens for a layer. - x3 = keras.layers.Input(shape=(10,)) - y3 = model(x3) - self.assertEqual(len(model.updates), 4) - - # But if you call the inner BN layer independently, you don't affect - # the model's updates. - x4 = keras.layers.Input(shape=(10,)) - _ = bn(x4) - self.assertEqual(len(model.updates), 4) - ``` - - Returns: - A list of update ops. - """ - if not self.trainable and not self.stateful: - return [] - - updates = [] - for layer in self.layers: - updates += layer.updates - - # `updates` might contain irrelevant updates, so it needs to be filtered - # with respect to inputs the model has been called on. - relevant_inputs = [] - for i in range(len(self._inbound_nodes)): - inputs = self.get_input_at(i) - if isinstance(inputs, list): - relevant_inputs += inputs - else: - relevant_inputs.append(inputs) - reachable = layers_util.get_reachable_from_inputs(relevant_inputs, updates) - relevant_conditional_updates = [x for x in updates if x in reachable] - unconditional_updates = [ - x for x in updates if x._unconditional_update] # pylint: disable=protected-access - # A layer could be used multiple times in a nested structure, - # so the updates list must be de-duped. - return list(set( - relevant_conditional_updates + unconditional_updates + self._updates)) - - @property - def losses(self): - """Retrieve the network's losses. - - Will only include losses that are either - unconditional, or conditional on inputs to this model - (e.g. will not include losses that depend on tensors - that aren't inputs to this model). - - Returns: - A list of loss tensors. - """ - losses = [] - if context.in_eager_mode(): - for layer in self.layers: - losses += layer.losses - return losses - - for layer in self.layers: - losses += layer.losses - - relevant_inputs = [] - for i in range(len(self._inbound_nodes)): - inputs = self.get_input_at(i) - if isinstance(inputs, list): - relevant_inputs += inputs - else: - relevant_inputs.append(inputs) - reachable = layers_util.get_reachable_from_inputs(relevant_inputs, losses) - relevant_conditional_losses = [x for x in losses if x in reachable] - unconditional_losses = [ - x for x in losses if x._unconditional_loss] # pylint: disable=protected-access - return list(set( - relevant_conditional_losses + unconditional_losses + self._losses)) - - @property - def trainable_weights(self): - if not self.trainable: - return [] - weights = [] - for layer in self.layers: - weights += layer.trainable_weights - return weights - - @property - def non_trainable_weights(self): - weights = [] - for layer in self.layers: - weights += layer.non_trainable_weights - if not self.trainable: - trainable_weights = [] - for layer in self.layers: - trainable_weights += layer.trainable_weights - return trainable_weights + weights - return weights - - @property - def input_spec(self): - """Gets the network's input specs. - - Returns: - A list of `InputSpec` instances (one per input to the model) - or a single instance if the model has only one input. - """ - specs = [] - for layer in self._input_layers: - if layer.input_spec is None: - specs.append(None) - else: - if not isinstance(layer.input_spec, list): - raise TypeError('Layer ' + layer.name + - ' has an input_spec attribute that ' - 'is not a list. We expect a list. ' - 'Found input_spec = ' + str(layer.input_spec)) - specs += layer.input_spec - if len(specs) == 1: - return specs[0] - return specs - - def call(self, inputs, mask=None): - """Call the model on new inputs. - - In this case `call` just reapplies - all ops in the graph to the new inputs - (e.g. build a new computational graph from the provided inputs). - - Arguments: - inputs: A tensor or list of tensors. - mask: A mask or list of masks. A mask can be - either a tensor or None (no mask). - - Returns: - A tensor if there is a single output, or - a list of tensors if there are more than one outputs. - """ - inputs = nest.flatten(inputs) - if mask is None: - masks = [None for _ in range(len(inputs))] - else: - masks = nest.flatten(mask) - - if context.in_graph_mode(): - # Try to retrieve cached outputs if the layer has already been called - # on these exact inputs. - cache_key = (layers_util.object_list_uid(inputs) - + '_' + layers_util.object_list_uid(masks)) - if cache_key in self._output_tensor_cache: - # Cache hit. - return self._output_tensor_cache[cache_key] - # Actually apply the network graph to the new inputs. - outputs, _ = self._run_internal_graph(inputs, masks) - return outputs - - def compute_output_shape(self, input_shape): - if isinstance(input_shape, list): - input_shapes = [] - for shape in input_shape: - if shape is not None: - input_shapes.append(tuple(tensor_shape.TensorShape(shape).as_list())) - else: - input_shapes.append(None) - else: - if input_shape is not None: - input_shapes = [tuple(tensor_shape.TensorShape(input_shape).as_list())] - else: - input_shapes = [None] - - if len(input_shapes) != len(self._input_layers): - raise ValueError('Invalid input_shape argument ' + str(input_shape) + - ': model has ' + str(len(self._input_layers)) + - ' tensor inputs.') - - cache_key = layers_util.object_list_uid(input_shapes) - if cache_key not in self._output_shape_cache: - # Cache miss. We have to run the network graph manually (recursive calls - # to `compute_output_shape`). - layers_to_output_shapes = {} - for i in range(len(input_shapes)): - layer = self._input_layers[i] - input_shape = input_shapes[i] - # It's an input layer: then `compute_output_shape` is identity, - # and there is only one node and one tensor output. - shape_key = layer.name + '_0_0' - layers_to_output_shapes[shape_key] = input_shape - - depth_keys = list(self._nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - # Iterate over nodes, by depth level. - if len(depth_keys) > 1: - for depth in depth_keys: - nodes = self._nodes_by_depth[depth] - for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer - if layer in self._input_layers: - # We've already covered the input layers - # a few lines above. - continue - # Potentially redundant list, - # same size as node.input_tensors. - input_shapes = [] - for j in range(len(node.inbound_layers)): - inbound_layer = node.inbound_layers[j] - node_index = node.node_indices[j] - tensor_index = node.tensor_indices[j] - shape_key = inbound_layer.name + '_%s_%s' % (node_index, - tensor_index) - input_shape = layers_to_output_shapes[shape_key] - input_shapes.append(input_shape) - - if len(input_shapes) == 1: - output_shape = layer.compute_output_shape(input_shapes[0]) - else: - output_shape = layer.compute_output_shape(input_shapes) - if isinstance(output_shape, list): - output_shapes = [ - tuple(tensor_shape.TensorShape(shape).as_list()) - for shape in output_shape - ] - else: - output_shapes = [ - tuple(tensor_shape.TensorShape(output_shape).as_list()) - ] - - node_index = layer._inbound_nodes.index(node) # pylint: disable=protected-access - for j in range(len(output_shapes)): - shape_key = layer.name + '_%s_%s' % (node_index, j) - layers_to_output_shapes[shape_key] = output_shapes[j] - - # Read final output shapes from layers_to_output_shapes. - output_shapes = [] - for i in range(len(self._output_layers)): - layer, node_index, tensor_index = self._output_coordinates[i] - shape_key = layer.name + '_%s_%s' % (node_index, tensor_index) - output_shapes.append(layers_to_output_shapes[shape_key]) - # Store in cache. - self._output_shape_cache[cache_key] = output_shapes - else: - # Cache hit. - output_shapes = self._output_shape_cache[cache_key] - - if isinstance(output_shapes, list): - if len(output_shapes) == 1: - return tensor_shape.TensorShape(output_shapes[0]) - else: - return [tensor_shape.TensorShape(shape) for shape in output_shapes] - else: - return tensor_shape.TensorShape(output_shapes) - - def _run_internal_graph(self, inputs, masks=None): - """Computes output tensors for new inputs. - - # Note: - - Expects `inputs` to be a list (potentially with 1 element). - - Can be run on non-Keras tensors. - - Arguments: - inputs: List of tensors - masks: List of masks (tensors or None). - - Returns: - Three lists: output_tensors, output_masks, output_shapes - """ - # Note: masking support is relevant mainly for Keras. - # It cannot be factored out without having the fully reimplement the network - # calling logic on the Keras side. We choose to incorporate it in - # GraphNetwork because 1) it may be useful to fully support in tf.layers in - # the future and 2) Keras is a major user of GraphNetwork. If you don't - # use masking, it does not interfere with regular behavior at all and you - # can ignore it. - if masks is None: - masks = [None for _ in range(len(inputs))] - - # Dictionary mapping reference tensors to tuples - # (computed tensor, compute mask) - # we assume a 1:1 mapping from tensor to mask - # TODO(fchollet): raise exception when a `.compute_mask()` call - # does not return a list the same size as `call` - tensor_map = {} - for x, y, mask in zip(self.inputs, inputs, masks): - tensor_map[str(id(x))] = (y, mask) - - depth_keys = list(self._nodes_by_depth.keys()) - depth_keys.sort(reverse=True) - for depth in depth_keys: - nodes = self._nodes_by_depth[depth] - for node in nodes: - # This is always a single layer, never a list. - layer = node.outbound_layer - reference_input_tensors = node.input_tensors - reference_output_tensors = node.output_tensors - - # If all previous input tensors are available in tensor_map, - # then call node.inbound_layer on them. - computed_data = [] # List of tuples (input, mask). - for x in reference_input_tensors: - if str(id(x)) in tensor_map: - computed_data.append(tensor_map[str(id(x))]) - - if len(computed_data) == len(reference_input_tensors): - # Call layer (reapplying ops to new inputs). - with ops.name_scope(layer.name): - if node.arguments: - kwargs = node.arguments - else: - kwargs = {} - if len(computed_data) == 1: - computed_tensor, computed_mask = computed_data[0] - # Ensure mask propagation if applicable. - if 'mask' in estimator_util.fn_args(layer.call): - if 'mask' not in kwargs: - kwargs['mask'] = computed_mask - - output_tensors = nest.flatten( - layer.call(computed_tensor, **kwargs)) - if hasattr(layer, 'compute_mask'): - output_masks = nest.flatten( - layer.compute_mask(computed_tensor, computed_mask)) - else: - output_masks = [None for _ in range(len(output_tensors))] - computed_tensors = [computed_tensor] - computed_masks = [computed_mask] - else: - computed_tensors = [x[0] for x in computed_data] - computed_masks = [x[1] for x in computed_data] - if 'mask' in estimator_util.fn_args(layer.call): - if 'mask' not in kwargs: - kwargs['mask'] = computed_masks - output_tensors = nest.flatten( - layer.call(computed_tensors, **kwargs)) - if hasattr(layer, 'compute_mask'): - output_masks = nest.flatten( - layer.compute_mask(computed_tensors, computed_masks)) - else: - output_masks = [None for _ in range(len(output_tensors))] - - if context.in_graph_mode(): - if layer.activity_regularizer is not None: - regularization_losses = [ - layer.activity_regularizer(x) for x in computed_tensors - ] - # Apply activity regularizer if any: - layer.add_loss(regularization_losses, computed_tensors) - - # Update tensor_map. - for x, y, mask in zip(reference_output_tensors, output_tensors, - output_masks): - tensor_map[str(id(x))] = (y, mask) - - output_tensors = [] - output_masks = [] - output_shapes = [] - for x in self.outputs: - assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x) - tensor, mask = tensor_map[str(id(x))] - output_shapes.append(layers_util.static_shape(x)) - output_tensors.append(tensor) - output_masks.append(mask) - - if len(output_tensors) == 1: - output_tensors = output_tensors[0] - if output_shapes is not None: - output_shapes = output_shapes[0] - if output_masks is not None: - output_masks = output_masks[0] - - if context.in_graph_mode(): - # Update cache; - # keys are based on ids on input tensors and inputs masks. - cache_key = (layers_util.object_list_uid(inputs) - + '_' + layers_util.object_list_uid(masks)) - self._output_tensor_cache[cache_key] = output_tensors - self._output_mask_cache[cache_key] = output_masks - - if output_shapes is not None: - input_shapes = [layers_util.static_shape(x) for x in inputs] - cache_key = layers_util.object_list_uid(input_shapes) - self._output_shape_cache[cache_key] = output_shapes - - return output_tensors, output_masks - - -def _make_node_key(layer_name, node_index): - return layer_name + '_ib-' + str(node_index) diff --git a/tensorflow/python/layers/network_test.py b/tensorflow/python/layers/network_test.py deleted file mode 100644 index f46ebdf2af10a44abd62470b7e85c7f2a04c6d57..0000000000000000000000000000000000000000 --- a/tensorflow/python/layers/network_test.py +++ /dev/null @@ -1,634 +0,0 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for tf.layers.network.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import test_util -from tensorflow.python.layers import base as base_layers -from tensorflow.python.layers import core as core_layers -from tensorflow.python.layers import network as network_layers -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import sparse_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.platform import test - - -class BaseLayerCompatibilityTest(test.TestCase): - - def test_get_updates(self): - - class MyLayer(base_layers.Layer): - - def build(self, input_shape): - self.a = self.add_variable('a', - (1, 1), - 'float32', - trainable=False) - self.b = self.add_variable('b', - (1, 1), - 'float32', - trainable=False) - self.add_update(state_ops.assign_add(self.a, [[1.]])) - self.built = True - - def call(self, inputs): - self.add_update(state_ops.assign_add(self.a, inputs), - inputs=True) - return inputs + 1 - - x1 = network_layers.Input(shape=(1,)) - layer = MyLayer() - _ = layer.apply(x1) - - self.assertEqual(len(layer.updates), 2) - self.assertEqual(len(layer.get_updates_for(x1)), 1) - self.assertEqual(len(layer.get_updates_for(None)), 1) - - x2 = network_layers.Input(shape=(1,)) - y2 = layer.apply(x2) - - self.assertEqual(len(layer.updates), 3) - self.assertEqual(len(layer.get_updates_for(x1)), 1) - self.assertEqual(len(layer.get_updates_for(x2)), 1) - self.assertEqual(len(layer.get_updates_for(None)), 1) - - network = network_layers.GraphNetwork(x2, y2) - self.assertEqual(len(network.updates), 2) - self.assertEqual(len(network.get_updates_for(x1)), 0) - self.assertEqual(len(network.get_updates_for(x2)), 1) - self.assertEqual(len(network.get_updates_for(None)), 1) - - x3 = network_layers.Input(shape=(1,)) - _ = layer.apply(x3) - self.assertEqual(len(network.updates), 2) - - x4 = network_layers.Input(shape=(1,)) - _ = network(x4) - self.assertEqual(len(network.updates), 3) - self.assertEqual(len(network.get_updates_for(x2)), 1) - self.assertEqual(len(network.get_updates_for(x4)), 1) - self.assertEqual(len(network.get_updates_for(None)), 1) - - network.add_update(state_ops.assign_add(layer.a, [[1]])) - self.assertEqual(len(network.updates), 4) - self.assertEqual(len(network.get_updates_for(None)), 2) - - network.add_update(state_ops.assign_add(layer.a, x4), inputs=True) - self.assertEqual(len(network.updates), 5) - self.assertEqual(len(network.get_updates_for(x4)), 2) - - def test_get_losses(self): - - class MyLayer(base_layers.Layer): - - def build(self, input_shape): - self.a = self.add_variable('a', - (1, 1), - 'float32', - trainable=False) - self.b = self.add_variable('b', - (1, 1), - 'float32', - trainable=False) - self.add_loss(math_ops.reduce_sum(self.a)) - self.built = True - - def call(self, inputs): - self.add_loss(math_ops.reduce_sum(inputs), - inputs=True) - return inputs + 1 - - x1 = network_layers.Input(shape=(1,)) - layer = MyLayer() - _ = layer.apply(x1) - - self.assertEqual(len(layer.losses), 2) - self.assertEqual(len(layer.get_losses_for(x1)), 1) - self.assertEqual(len(layer.get_losses_for(None)), 1) - - x2 = network_layers.Input(shape=(1,)) - y2 = layer.apply(x2) - - self.assertEqual(len(layer.losses), 3) - self.assertEqual(len(layer.get_losses_for(x1)), 1) - self.assertEqual(len(layer.get_losses_for(x2)), 1) - self.assertEqual(len(layer.get_losses_for(None)), 1) - - network = network_layers.GraphNetwork(x2, y2) - self.assertEqual(len(network.losses), 2) - self.assertEqual(len(network.get_losses_for(x1)), 0) - self.assertEqual(len(network.get_losses_for(x2)), 1) - self.assertEqual(len(network.get_losses_for(None)), 1) - - x3 = network_layers.Input(shape=(1,)) - _ = layer.apply(x3) - self.assertEqual(len(network.losses), 2) - - x4 = network_layers.Input(shape=(1,)) - _ = network(x4) - self.assertEqual(len(network.losses), 3) - self.assertEqual(len(network.get_losses_for(x2)), 1) - self.assertEqual(len(network.get_losses_for(x4)), 1) - self.assertEqual(len(network.get_losses_for(None)), 1) - - network.add_loss(math_ops.reduce_sum(layer.a)) - self.assertEqual(len(network.losses), 4) - self.assertEqual(len(network.get_losses_for(None)), 2) - - network.add_loss(math_ops.reduce_sum(x4), inputs=True) - self.assertEqual(len(network.losses), 5) - self.assertEqual(len(network.get_losses_for(x4)), 2) - - def testTopologicalAttributes(self): - # test layer attributes / methods related to cross-layer connectivity. - a = network_layers.Input(shape=(32,), name='input_a') - b = network_layers.Input(shape=(32,), name='input_b') - - # test input, output, input_shape, output_shape - test_layer = core_layers.Dense(16, name='test_layer') - a_test = test_layer(a) - self.assertEqual(test_layer.input, a) - self.assertEqual(test_layer.output, a_test) - self.assertEqual(test_layer.input_shape, (None, 32)) - self.assertEqual(test_layer.output_shape, (None, 16)) - - # test `get_*_at` methods - dense = core_layers.Dense(16, name='dense_1') - a_2 = dense(a) - b_2 = dense(b) - - self.assertEqual(dense.get_input_at(0), a) - self.assertEqual(dense.get_input_at(1), b) - self.assertEqual(dense.get_output_at(0), a_2) - self.assertEqual(dense.get_output_at(1), b_2) - self.assertEqual(dense.get_input_shape_at(0), (None, 32)) - self.assertEqual(dense.get_input_shape_at(1), (None, 32)) - self.assertEqual(dense.get_output_shape_at(0), (None, 16)) - self.assertEqual(dense.get_output_shape_at(1), (None, 16)) - - # Test invalid value for attribute retrieval. - with self.assertRaises(ValueError): - dense.get_input_at(2) - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.input - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.output - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.output_shape - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - _ = new_dense.input_shape - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - a = network_layers.Input(shape=(3, 32)) - a = network_layers.Input(shape=(5, 32)) - a_2 = dense(a) - b_2 = dense(b) - _ = new_dense.input_shape - with self.assertRaises(AttributeError): - new_dense = core_layers.Dense(16) - a = network_layers.Input(shape=(3, 32)) - a = network_layers.Input(shape=(5, 32)) - a_2 = dense(a) - b_2 = dense(b) - _ = new_dense.output_shape - - def testTopologicalAttributesMultiOutputLayer(self): - - class PowersLayer(base_layers.Layer): - - def call(self, inputs): - return [inputs**2, inputs**3] - - x = network_layers.Input(shape=(32,)) - test_layer = PowersLayer() - p1, p2 = test_layer(x) # pylint: disable=not-callable - - self.assertEqual(test_layer.input, x) - self.assertEqual(test_layer.output, [p1, p2]) - self.assertEqual(test_layer.input_shape, (None, 32)) - self.assertEqual(test_layer.output_shape, [(None, 32), (None, 32)]) - - def testTopologicalAttributesMultiInputLayer(self): - - class AddLayer(base_layers.Layer): - - def call(self, inputs): - assert len(inputs) == 2 - return inputs[0] + inputs[1] - - a = network_layers.Input(shape=(32,)) - b = network_layers.Input(shape=(32,)) - test_layer = AddLayer() - y = test_layer([a, b]) # pylint: disable=not-callable - - self.assertEqual(test_layer.input, [a, b]) - self.assertEqual(test_layer.output, y) - self.assertEqual(test_layer.input_shape, [(None, 32), (None, 32)]) - self.assertEqual(test_layer.output_shape, (None, 32)) - - -class NetworkTest(test.TestCase): - - def testBasicNetwork(self): - # minimum viable network - x = network_layers.Input(shape=(32,)) - dense = core_layers.Dense(2) - y = dense(x) - network = network_layers.GraphNetwork(x, y, name='dense_network') - - # test basic attributes - self.assertEqual(network.name, 'dense_network') - self.assertEqual(len(network.layers), 2) # InputLayer + Dense - self.assertEqual(network.layers[1], dense) - self.assertEqual(network.weights, dense.weights) - self.assertEqual(network.trainable_weights, dense.trainable_weights) - self.assertEqual(network.non_trainable_weights, dense.non_trainable_weights) - - # test callability on Input - x_2 = network_layers.Input(shape=(32,)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 2]) - - # test callability on regular tensor - x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 2]) - - # test network `trainable` attribute - network.trainable = False - self.assertEqual(network.weights, dense.weights) - self.assertEqual(network.trainable_weights, []) - self.assertEqual(network.non_trainable_weights, - dense.trainable_weights + dense.non_trainable_weights) - - def test_node_construction(self): - # test graph topology construction basics - a = network_layers.Input(shape=(32,), name='input_a') - b = network_layers.Input(shape=(32,), name='input_b') - - self.assertEqual(a.get_shape().as_list(), [None, 32]) - a_layer, a_node_index, a_tensor_index = a._keras_history - b_layer, _, _ = b._keras_history - self.assertEqual(len(a_layer._inbound_nodes), 1) - self.assertEqual(a_tensor_index, 0) - node = a_layer._inbound_nodes[a_node_index] - self.assertEqual(node.outbound_layer, a_layer) - - self.assertEqual(node.inbound_layers, []) - self.assertEqual(node.input_tensors, [a]) - self.assertEqual(node.input_shapes, [(None, 32)]) - self.assertEqual(node.output_tensors, [a]) - self.assertEqual(node.output_shapes, [(None, 32)]) - - dense = core_layers.Dense(16, name='dense_1') - dense(a) - dense(b) - - self.assertEqual(len(dense._inbound_nodes), 2) - self.assertEqual(len(dense._outbound_nodes), 0) - self.assertEqual(dense._inbound_nodes[0].inbound_layers, [a_layer]) - self.assertEqual(dense._inbound_nodes[0].outbound_layer, dense) - self.assertEqual(dense._inbound_nodes[1].inbound_layers, [b_layer]) - self.assertEqual(dense._inbound_nodes[1].outbound_layer, dense) - self.assertEqual(dense._inbound_nodes[0].input_tensors, [a]) - self.assertEqual(dense._inbound_nodes[1].input_tensors, [b]) - - # Test config - config_0 = dense._inbound_nodes[0].get_config() - self.assertEqual(config_0['outbound_layer'], dense.name) - - def testMultiInputNetwork(self): - a = network_layers.Input(shape=(32,), name='input_a') - b = network_layers.Input(shape=(32,), name='input_b') - - class AddLayer(base_layers.Layer): - - def call(self, inputs): - assert len(inputs) == 2 - return inputs[0] + inputs[1] - - c = AddLayer()([a, b]) # pylint: disable=not-callable - network = network_layers.GraphNetwork([a, b], c) - self.assertEqual(len(network.layers), 3) # 2 * InputLayer + AddLayer - - # Test callability. - a2 = network_layers.Input(shape=(32,)) - b2 = network_layers.Input(shape=(32,)) - c2 = network([a2, b2]) - self.assertEqual(c2.get_shape().as_list(), [None, 32]) - - def testMultiOutputNetwork(self): - x = network_layers.Input(shape=(32,)) - y1 = core_layers.Dense(2)(x) - y2 = core_layers.Dense(3)(x) - network = network_layers.GraphNetwork(x, [y1, y2]) - - self.assertEqual(len(network.layers), 3) # InputLayer + 2 * Dense - - # Test callability. - x2 = network_layers.Input(shape=(32,)) - outputs = network(x2) - - self.assertEqual(type(outputs), list) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) - self.assertEqual(outputs[1].get_shape().as_list(), [None, 3]) - - def testMultiInputMultiOutputNetworkSharedLayer(self): - a = network_layers.Input(shape=(32,), name='input_a') - b = network_layers.Input(shape=(32,), name='input_b') - - dense = core_layers.Dense(2) - - y1 = dense(a) - y2 = dense(b) - network = network_layers.GraphNetwork([a, b], [y1, y2]) - self.assertEqual(len(network.layers), 3) # 2 * InputLayer + Dense - - # Test callability. - a2 = network_layers.Input(shape=(32,)) - b2 = network_layers.Input(shape=(32,)) - outputs = network([a2, b2]) - - self.assertEqual(type(outputs), list) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) - self.assertEqual(outputs[1].get_shape().as_list(), [None, 2]) - - def testCrossDataFlows(self): - # Test the ability to have multi-output layers with outputs that get routed - # to separate layers - - class PowersLayer(base_layers.Layer): - - def call(self, inputs): - return [inputs**2, inputs**3] - - x = network_layers.Input(shape=(32,)) - p1, p2 = PowersLayer()(x) # pylint: disable=not-callable - y1 = core_layers.Dense(2)(p1) - y2 = core_layers.Dense(3)(p2) - network = network_layers.GraphNetwork(x, [y1, y2]) - - self.assertEqual(len(network.layers), 4) # InputLayer + 2 * Dense + PLayer - - # Test callability. - x2 = network_layers.Input(shape=(32,)) - outputs = network(x2) - - self.assertEqual(type(outputs), list) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].get_shape().as_list(), [None, 2]) - self.assertEqual(outputs[1].get_shape().as_list(), [None, 3]) - - def testNetworkAttributes(self): - x = network_layers.Input(shape=(32,)) - layer = core_layers.Dense(2, kernel_regularizer=lambda x: 0.01 * (x**2)) - z = layer(x) - dense = core_layers.Dense(2, name='dense') - dense.add_update(state_ops.assign_add(layer.kernel, layer.kernel * 2.)) - y = dense(z) - net = network_layers.GraphNetwork(x, y) - - # losses - self.assertEqual(len(net.losses), 1) - - # updates - self.assertEqual(len(net.updates), 1) - - # get_layer - self.assertEqual(net.get_layer('dense'), dense) - self.assertEqual(net.get_layer(index=2), dense) - with self.assertRaises(ValueError): - net.get_layer('dense_unknown') - with self.assertRaises(ValueError): - net.get_layer() - with self.assertRaises(ValueError): - net.get_layer(index=4) - - # input, output - self.assertEqual(net.input, x) - self.assertEqual(net.output, y) - - # input_shape, output_shape - self.assertEqual(net.input_shape, (None, 32)) - self.assertEqual(net.output_shape, (None, 2)) - - # get_*_at - self.assertEqual(net.get_input_at(0), x) - self.assertEqual(net.get_output_at(0), y) - - # compute_output_shape - self.assertEqual(net.compute_output_shape((3, 32)).as_list(), [3, 2]) - - def testInvalidNetworks(self): - # redundant inputs - x = network_layers.Input(shape=(32,)) - y = core_layers.Dense(2)(x) - with self.assertRaises(ValueError): - network_layers.GraphNetwork([x, x], y) - - # inputs that don't come from Input - x = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y = core_layers.Dense(2)(x) - with self.assertRaises(ValueError): - network_layers.GraphNetwork(x, y) - - # inputs that don't come from Input but have a layer history - x = network_layers.Input(shape=(32,)) - x = core_layers.Dense(32)(x) - y = core_layers.Dense(2)(x) - with self.assertRaises(ValueError): - network_layers.GraphNetwork(x, y) - - # outputs that don't come from layers - x = network_layers.Input(shape=(32,)) - y = core_layers.Dense(2)(x) - y = 2 * y - with self.assertRaises(ValueError): - network_layers.GraphNetwork(x, y) - - # disconnected graphs - x1 = network_layers.Input(shape=(32,)) - x2 = network_layers.Input(shape=(32,)) - y = core_layers.Dense(2)(x1) - with self.assertRaises(ValueError): - network_layers.GraphNetwork(x2, y) - - # redundant layer names - x = network_layers.Input(shape=(32,)) - z = core_layers.Dense(2, name='dense')(x) - y = core_layers.Dense(2, name='dense')(z) - with self.assertRaises(ValueError): - network_layers.GraphNetwork(x, y) - - def testInputTensorWrapping(self): - x = array_ops.placeholder(dtype='float32', shape=(None, 32)) - x = network_layers.Input(tensor=x) - y = core_layers.Dense(2)(x) - network_layers.GraphNetwork(x, y) - - def testExplicitBatchSize(self): - x = network_layers.Input(shape=(32,), batch_size=3) - y = core_layers.Dense(2)(x) - self.assertEqual(y.get_shape().as_list(), [3, 2]) - - def testNetworkRecursion(self): - # test the ability of networks to be used as layers inside networks. - a = network_layers.Input(shape=(32,)) - b = core_layers.Dense(2)(a) - net = network_layers.GraphNetwork(a, b) - - c = network_layers.Input(shape=(32,)) - d = net(c) - - recursive_net = network_layers.GraphNetwork(c, d) - self.assertEqual(len(recursive_net.layers), 2) - self.assertEqual(recursive_net.layers[1], net) - self.assertEqual(len(recursive_net.weights), 2) - - # test callability - x = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y = recursive_net(x) - self.assertEqual(y.get_shape().as_list(), [None, 2]) - - def testSparseInput(self): - - class SparseSoftmax(base_layers.Layer): - - def call(self, inputs): - return sparse_ops.sparse_softmax(inputs) - - x = network_layers.Input(shape=(32,), sparse=True) - y = SparseSoftmax()(x) # pylint: disable=not-callable - network = network_layers.GraphNetwork(x, y) - - self.assertEqual(len(network.layers), 2) - self.assertEqual(network.layers[0].sparse, True) - - @test_util.run_in_graph_and_eager_modes() - def testMaskingSingleInput(self): - - class MaskedLayer(base_layers.Layer): - - def call(self, inputs, mask=None): - if mask is not None: - return inputs * mask - return inputs - - def compute_mask(self, inputs, mask=None): - return array_ops.ones_like(inputs) - - if context.in_graph_mode(): - x = network_layers.Input(shape=(32,)) - y = MaskedLayer()(x) # pylint: disable=not-callable - network = network_layers.GraphNetwork(x, y) - - # test callability on Input - x_2 = network_layers.Input(shape=(32,)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 32]) - - # test callability on regular tensor - x_2 = array_ops.placeholder(dtype='float32', shape=(None, 32)) - y_2 = network(x_2) - self.assertEqual(y_2.get_shape().as_list(), [None, 32]) - else: - a = constant_op.constant([2] * 32) - mask = constant_op.constant([0, 1] * 16) - a._keras_mask = mask - b = MaskedLayer().apply(a) - self.assertTrue(hasattr(b, '_keras_mask')) - self.assertAllEqual(self.evaluate(array_ops.ones_like(mask)), - self.evaluate(getattr(b, '_keras_mask'))) - self.assertAllEqual(self.evaluate(a * mask), self.evaluate(b)) - - -class DeferredModeTest(test.TestCase): - - def testDeferredTensorAttributes(self): - x = base_layers._DeferredTensor(shape=(None, 2), dtype='float32', name='x') - self.assertEqual(str(x), - 'DeferredTensor(\'x\', shape=(?, 2), dtype=float32)') - self.assertEqual(repr(x), - '<_DeferredTensor \'x\' shape=(?, 2) dtype=float32>') - - @test_util.run_in_graph_and_eager_modes() - def testSimpleNetworkBuilding(self): - inputs = network_layers.Input(shape=(32,)) - if context.in_eager_mode(): - self.assertIsInstance(inputs, base_layers._DeferredTensor) - self.assertEqual(inputs.dtype.name, 'float32') - self.assertEqual(inputs.shape.as_list(), [None, 32]) - - x = core_layers.Dense(2)(inputs) - if context.in_eager_mode(): - self.assertIsInstance(x, base_layers._DeferredTensor) - self.assertEqual(x.dtype.name, 'float32') - self.assertEqual(x.shape.as_list(), [None, 2]) - - outputs = core_layers.Dense(4)(x) - network = network_layers.GraphNetwork(inputs, outputs) - self.assertIsInstance(network, network_layers.GraphNetwork) - - if context.in_eager_mode(): - # It should be possible to call such a network on EagerTensors. - inputs = constant_op.constant( - np.random.random((10, 32)).astype('float32')) - outputs = network(inputs) - self.assertEqual(outputs.shape.as_list(), [10, 4]) - - @test_util.run_in_graph_and_eager_modes() - def testMultiIONetworkbuilding(self): - input_a = network_layers.Input(shape=(32,)) - input_b = network_layers.Input(shape=(16,)) - a = core_layers.Dense(16)(input_a) - - class AddLayer(base_layers.Layer): - - def call(self, inputs): - return inputs[0] + inputs[1] - - def compute_output_shape(self, input_shape): - return input_shape[0] - - c = AddLayer()([a, input_b]) # pylint: disable=not-callable - c = core_layers.Dense(2)(c) - - network = network_layers.GraphNetwork([input_a, input_b], [a, c]) - if context.in_eager_mode(): - a_val = constant_op.constant( - np.random.random((10, 32)).astype('float32')) - b_val = constant_op.constant( - np.random.random((10, 16)).astype('float32')) - outputs = network([a_val, b_val]) - self.assertEqual(len(outputs), 2) - self.assertEqual(outputs[0].shape.as_list(), [10, 16]) - self.assertEqual(outputs[1].shape.as_list(), [10, 2]) - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/python/layers/normalization.py b/tensorflow/python/layers/normalization.py index 656d566ab5497016244d717b3e85bee93f1d9796..d83292b80963d942023b5d086a089af53008efe0 100644 --- a/tensorflow/python/layers/normalization.py +++ b/tensorflow/python/layers/normalization.py @@ -94,8 +94,8 @@ class BatchNormalization(base.Layer): and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `momentum` is still applied to get the means and variances for inference. - fused: if `True`, use a faster, fused implementation if possible. - If `None`, use the system recommended implementation. + fused: if `None` or `True`, use a faster, fused implementation if possible. + If `False`, use the system recommended implementation. trainable: Boolean, if `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`, @@ -493,6 +493,7 @@ class BatchNormalization(base.Layer): return (r, d, new_mean, new_variance) def call(self, inputs, training=False): + in_eager_mode = context.in_eager_mode() if self.virtual_batch_size is not None: # Virtual batches (aka ghost batches) can be simulated by reshaping the # Tensor and reusing the existing batch norm implementation @@ -595,6 +596,9 @@ class BatchNormalization(base.Layer): axis=1, keep_dims=True) def _do_update(var, value): + if in_eager_mode and not self.trainable: + return + return moving_averages.assign_moving_average( var, value, self.momentum, zero_debias=False) @@ -725,8 +729,8 @@ def batch_normalization(inputs, and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `momentum` is still applied to get the means and variances for inference. - fused: if `True`, use a faster, fused implementation if possible. - If `None`, use the system recommended implementation. + fused: if `None` or `True`, use a faster, fused implementation if possible. + If `False`, use the system recommended implementation. virtual_batch_size: An `int`. By default, `virtual_batch_size` is `None`, which means batch normalization is performed across the whole batch. When `virtual_batch_size` is not `None`, instead perform "Ghost Batch diff --git a/tensorflow/python/layers/utils.py b/tensorflow/python/layers/utils.py index 1bbf4e6dffd3415ba246e26cd92923df8116edab..3b156c36a2ff35fb9e05af1406d7b3f6cf883394 100644 --- a/tensorflow/python/layers/utils.py +++ b/tensorflow/python/layers/utils.py @@ -20,9 +20,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.ops import variables from tensorflow.python.ops import control_flow_ops from tensorflow.python.framework import ops +from tensorflow.python.framework import smart_cond as smart_module from tensorflow.python.framework import tensor_util from tensorflow.python.util import nest @@ -178,67 +180,56 @@ def deconv_output_length(input_length, filter_size, padding, stride): return input_length -def smart_cond(pred, fn1, fn2, name=None): - """Return either `fn1()` or `fn2()` based on the boolean predicate `pred`. +def smart_cond(pred, true_fn=None, false_fn=None, name=None): + """Return either `true_fn()` if predicate `pred` is true else `false_fn()`. - If `pred` is a bool or has a constant value, we return either `fn1()` - or `fn2()`, otherwise we use `tf.cond` to dynamically route to both. + If `pred` is a bool or has a constant value, we return either `true_fn()` + or `false_fn()`, otherwise we use `tf.cond` to dynamically route to both. Arguments: - pred: A scalar determining whether to return the result of `fn1` or `fn2`. - fn1: The callable to be performed if pred is true. - fn2: The callable to be performed if pred is false. + pred: A scalar determining whether to return the result of `true_fn` or + `false_fn`. + true_fn: The callable to be performed if pred is true. + false_fn: The callable to be performed if pred is false. name: Optional name prefix when using `tf.cond`. Returns: - Tensors returned by the call to either `fn1` or `fn2`. + Tensors returned by the call to either `true_fn` or `false_fn`. Raises: - TypeError: If `fn1` or `fn2` is not callable. + TypeError: If `true_fn` or `false_fn` is not callable. """ - if not callable(fn1): - raise TypeError('`fn1` must be callable.') - if not callable(fn2): - raise TypeError('`fn2` must be callable.') - - pred_value = constant_value(pred) - if pred_value is not None: - if pred_value: - return fn1() - else: - return fn2() - else: - return control_flow_ops.cond(pred, true_fn=fn1, false_fn=fn2, name=name) + if isinstance(pred, variables.Variable): + return control_flow_ops.cond( + pred, true_fn=true_fn, false_fn=false_fn, name=name) + return smart_module.smart_cond( + pred, true_fn=true_fn, false_fn=false_fn, name=name) def constant_value(pred): """Return the bool value for `pred`, or None if `pred` had a dynamic value. - Arguments: - pred: A scalar, either a Python bool or a TensorFlow boolean variable - or tensor, or the Python integer 1 or 0. + Arguments: + pred: A scalar, either a Python bool or a TensorFlow boolean variable + or tensor, or the Python integer 1 or 0. - Returns: - True or False if `pred` has a constant boolean value, None otherwise. + Returns: + True or False if `pred` has a constant boolean value, None otherwise. - Raises: - TypeError: If `pred` is not a Variable, Tensor or bool. - """ + Raises: + TypeError: If `pred` is not a Variable, Tensor or bool, or Python + interger 1 or 0. + """ # Allow integer booleans. - if pred == 0: - pred = False - elif pred == 1: - pred = True - - if isinstance(pred, bool): - pred_value = pred - elif isinstance(pred, variables.Variable): - pred_value = None - elif isinstance(pred, ops.Tensor): - pred_value = tensor_util.constant_value(pred) - else: - raise TypeError('`pred` must be a Tensor, a Variable, or a Python bool.') - return pred_value + if isinstance(pred, int): + if pred == 1: + pred = True + elif pred == 0: + pred = False + + if isinstance(pred, variables.Variable): + return None + return smart_module.smart_constant_value(pred) def object_list_uid(object_list): diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py index ad409ad7e5a152bbc4312e1d16f324bb8be71c33..96f5f81c1f04b7d64ddbd6fd461348c6986d9ff6 100644 --- a/tensorflow/python/ops/array_ops.py +++ b/tensorflow/python/ops/array_ops.py @@ -134,7 +134,10 @@ def identity(input, name=None): # pylint: disable=redefined-builtin input = ops.convert_to_tensor(input) in_device = input.device # TODO(ashankar): Does 'identity' need to invoke execution callbacks? - if context.context().device_name != in_device: + context_device = context.context().device_name + if not context_device: + context_device = "/job:localhost/replica:0/task:0/device:CPU:0" + if context_device != in_device: return input._copy() # pylint: disable=protected-access return input @@ -386,6 +389,13 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32): Returns: A `Tensor` of type `out_type`. Defaults to `tf.int32`. """ + if context.in_eager_mode() and not isinstance( + input, (sparse_tensor.SparseTensor, + sparse_tensor.SparseTensorValue)): + size_ = 1 + for dim in ops.convert_to_tensor(input)._shape_tuple(): # pylint: disable=protected-access + size_ *= dim + return size_ with ops.name_scope(name, "Size", [input]) as name: if isinstance(input, (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)): @@ -394,8 +404,11 @@ def size_internal(input, name=None, optimize=True, out_type=dtypes.int32): else: input_tensor = ops.convert_to_tensor(input) input_shape = input_tensor.get_shape() - if optimize and input_shape.is_fully_defined(): - return constant(input_shape.num_elements(), out_type, name=name) + if optimize: + if input_shape.is_fully_defined(): + return constant(input_shape.num_elements(), out_type, name=name) + if input_shape.dims and any(dim == 0 for dim in input_shape.dims): + return constant(0, out_type, name=name) return gen_array_ops.size(input, name=name, out_type=out_type) @@ -605,7 +618,7 @@ def slice(input_, begin, size, name=None): Note that @{tf.Tensor.__getitem__} is typically a more pythonic way to perform slices, as it allows you to write `foo[3:7, :-2]` instead of - `tf.slice([3, 0], [4, foo.get_shape()[1]-2])`. + `tf.slice(foo, [3, 0], [4, foo.get_shape()[1]-2])`. `begin` is zero-based; `size` is one-based. If `size[i]` is -1, all remaining elements in dimension i are included in the @@ -1312,6 +1325,18 @@ def unique(x, out_idx=dtypes.int32, name=None): unique.__doc__ = gen_array_ops._unique.__doc__ +@tf_export("unique_with_counts") +def unique_with_counts(x, out_idx=dtypes.int32, name=None): + # TODO(yongtang): switch to v2 once API deprecation + # period (3 weeks) pass. + # TODO(yongtang): The documentation should also + # be updated when switch to v2. + return gen_array_ops._unique_with_counts(x, out_idx, name) + + +unique_with_counts.__doc__ = gen_array_ops._unique_with_counts.__doc__ + + @tf_export("split") def split(value, num_or_size_splits, axis=0, num=None, name="split"): """Splits a tensor into sub tensors. @@ -1390,6 +1415,14 @@ def transpose(a, perm=None, name="transpose", conjugate=False): `a.dtype` is either `complex64` or `complex128` then the values of `a` are conjugated and transposed. + @compatibility(numpy) + In `numpy` transposes are memory-efficient constant time operations as they + simply return a new view of the same data with adjusted `strides`. + + TensorFlow does not support strides, so `transpose` returns a new tensor with + the items permuted. + @end_compatibility + For example: ```python @@ -1490,6 +1523,14 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False): tf.matmul(matrix, tf.matrix_transpose(b)) ``` + @compatibility(numpy) + In `numpy` transposes are memory-efficient constant time operations as they + simply return a new view of the same data with adjusted `strides`. + + TensorFlow does not support strides, `matrix_transposes` return a new tensor + with the items permuted. + @end_compatibility + Args: a: A `Tensor` with `rank >= 2`. name: A name for the operation (optional). diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 0fd6e29a49c8e4e31e244bfbbfca525d72e4d811..64567ac54ae43acf6f8b674c46525db7a6c4fab7 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -334,9 +334,9 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None): @compatibility{eager} returns None Raises: - InvalidArgumentError if the check can be performed immediately and - `x == y` is False. The check can be performed immediately during - eager execution or if `x` and `y` are statically known. + InvalidArgumentError: if the check can be performed immediately and + `x == y` is False. The check can be performed immediately during eager + execution or if `x` and `y` are statically known. """ message = message or '' with ops.name_scope(name, 'assert_equal', [x, y, data]): diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py index e4ce2ab28a15f82e80194ab17ef939411982076a..b9a93c3bedfff1f398e3b42cedf02a2f0a3ddd5c 100644 --- a/tensorflow/python/ops/confusion_matrix.py +++ b/tensorflow/python/ops/confusion_matrix.py @@ -99,19 +99,16 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, name=None, weights=None): """Computes the confusion matrix from predictions and labels. - Calculate the Confusion Matrix for a pair of prediction and - label 1-D int arrays. - The matrix columns represent the prediction labels and the rows represent the real labels. The confusion matrix is always a 2-D array of shape `[n, n]`, where `n` is the number of valid labels for a given classification task. Both prediction and labels must be 1-D arrays of the same shape in order for this function to work. - If `num_classes` is None, then `num_classes` will be set to the one plus - the maximum value in either predictions or labels. - Class labels are expected to start at 0. E.g., if `num_classes` was - three, then the possible labels would be `[0, 1, 2]`. + If `num_classes` is `None`, then `num_classes` will be set to one plus the + maximum value in either predictions or labels. Class labels are expected to + start at 0. For example, if `num_classes` is 3, then the possible labels + would be `[0, 1, 2]`. If `weights` is not `None`, then each prediction contributes its corresponding weight to the total value of the confusion matrix cell. @@ -141,8 +138,9 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, weights: An optional `Tensor` whose shape matches `predictions`. Returns: - A k X k matrix representing the confusion matrix, where k is the number of - possible labels in the classification task. + A `Tensor` of type `dtype` with shape `[n, n]` representing the confusion + matrix, where `n` is the number of possible labels in the classification + task. Raises: ValueError: If both predictions and labels are not 1-D vectors and have @@ -188,7 +186,7 @@ def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32, weights = math_ops.cast(weights, dtype) shape = array_ops.stack([num_classes, num_classes]) - indices = array_ops.transpose(array_ops.stack([labels, predictions])) + indices = array_ops.stack([labels, predictions], axis=1) values = (array_ops.ones_like(predictions, dtype) if weights is None else weights) cm_sparse = sparse_tensor.SparseTensor( diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index c33f3512893a413dd4c5b9a1fd87c9bb498627f9..0815527c9644bcbc01f91ad01e716061963513bd 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -44,6 +44,7 @@ See the @{$python/control_flow_ops} guide. @@add_check_numerics_ops @@Assert @@Print +@@timestamp """ # pylint: disable=g-bad-name from __future__ import absolute_import @@ -177,8 +178,6 @@ def Assert(condition, data, summarize=None, name=None): condition, data, summarize, name="Assert") guarded_assert = cond(condition, no_op, true_assert, name="AssertGuard") - if context.in_eager_mode(): - return return guarded_assert.op @@ -1717,8 +1716,15 @@ class CondContext(ControlFlowContext): self._pivot = g.as_graph_element( ops.prepend_name_scope(context_def.pivot_name, import_scope)) self._branch = context_def.branch - super(CondContext, self).__init__( - values_def=context_def.values_def, import_scope=import_scope) + super(CondContext, self).__init__(values_def=context_def.values_def, + import_scope=import_scope) + # The predicate and pivot ops appear in self._values, but don't have self + # set as their control context. The __init__ call above will set self for + # all values, so manually override the predicate and pivot contexts here. + # pylint: disable=protected-access + self._pred.op._set_control_flow_context(self.outer_context) + self._pivot.op._set_control_flow_context(self.outer_context) + # pylint: enable=protected-access @property def pred(self): @@ -1766,13 +1772,9 @@ class CondContext(ControlFlowContext): context_def.branch = self._branch context_def.values_def.MergeFrom(super(CondContext, self)._to_values_def( export_scope)) - # TODO(b/72868227): enable this once the corresponding control_flow.proto - # changes have been checked in (they aren't checked in and this is - # disabled for now to ensure forwards compatibility). - if False: # pylint: disable=using-constant-test - for nested in self._nested_contexts: - nested_def = context_def.nested_contexts.add() - nested.to_control_flow_context_def(nested_def) + for nested in self._nested_contexts: + nested_def = context_def.nested_contexts.add() + nested.to_control_flow_context_def(nested_def) return context_def else: @@ -1784,14 +1786,10 @@ class CondContext(ControlFlowContext): ret = CondContext(context_def=context_def, import_scope=import_scope) - # TODO(b/72868227): remove "if hasattr(...)" once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is here for now to ensure forwards compatibility). - if hasattr(context_def, "nested_contexts"): - ret.Enter() - for nested_def in context_def.nested_contexts: - from_control_flow_context_def(nested_def) - ret.Exit() + ret.Enter() + for nested_def in context_def.nested_contexts: + from_control_flow_context_def(nested_def) + ret.Exit() return ret def to_control_flow_context_def(self, context_def, export_scope=None): @@ -1835,8 +1833,6 @@ class CondContext(ControlFlowContext): # pylint: disable=protected-access op._add_control_input(self._pivot.op) # pylint: enable=protected-access - for x in op.outputs: - self._values.add(x.name) else: for index in range(len(op.inputs)): x = op.inputs[index] @@ -1847,13 +1843,20 @@ class CondContext(ControlFlowContext): # pylint: enable=protected-access # Remove any external control dependency on this op. self._RemoveExternalControlEdges(op) - for x in op.outputs: - self._values.add(x.name) # pylint: disable=protected-access if op.graph._is_function(op.type) or op.type == "SymbolicGradient": op._add_control_input(self._pivot.op) # pylint: enable=protected-access + # Mark op's outputs as seen by this context and any outer contexts. + output_names = [x.name for x in op.outputs] + ctxt = self + while ctxt is not None: + # pylint: disable=protected-access + ctxt._values.update(output_names) + ctxt = ctxt._outer_context + # pylint: enable=protected-access + if self._outer_context or not util.IsLoopExit(op): op.graph.prevent_fetching(op) @@ -2104,10 +2107,7 @@ def cond(pred, # Only add non-nested conds to the collection. Any nested control flow will # be encapsulated in the root context. assert context_t.outer_context == context_f.outer_context - # TODO(b/72868227): remove "if True..." once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if True or context_t.outer_context is None: + if context_t.outer_context is None: ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_t) ops.add_to_collection(ops.GraphKeys.COND_CONTEXT, context_f) @@ -2330,13 +2330,9 @@ class WhileContext(ControlFlowContext): context_def.values_def.MergeFrom( super(WhileContext, self)._to_values_def( export_scope=export_scope)) - # TODO(b/72868227): remove "if True..." once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if False: # pylint: disable=using-constant-test - for nested in self._nested_contexts: - nested_def = context_def.nested_contexts.add() - nested.to_control_flow_context_def(nested_def) + for nested in self._nested_contexts: + nested_def = context_def.nested_contexts.add() + nested.to_control_flow_context_def(nested_def) return context_def else: @@ -2358,14 +2354,10 @@ class WhileContext(ControlFlowContext): """ ret = WhileContext(context_def=context_def, import_scope=import_scope) - # TODO(b/72868227): remove "if hasattr(...)" once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if hasattr(context_def, "nested_contexts"): - ret.Enter() - for nested_def in context_def.nested_contexts: - from_control_flow_context_def(nested_def, import_scope=import_scope) - ret.Exit() + ret.Enter() + for nested_def in context_def.nested_contexts: + from_control_flow_context_def(nested_def, import_scope=import_scope) + ret.Exit() return ret def GetWhileContext(self): @@ -3120,6 +3112,43 @@ def while_loop(cond, shape_invariants=[i0.get_shape(), tf.TensorShape([None, 2])]) ``` + Example which demonstrates non-strict semantics: In the following + example, the final value of the counter `i` does not depend on `x`. So + the `while_loop` can increment the counter parallel to updates of `x`. + However, because the loop counter at one loop iteration depends + on the value at the previous iteration, the loop counter itself cannot + be incremented in parallel. Hence if we just want the final value of the + counter (which we print on the line `print(sess.run(i))`), then + `x` will never be incremented, but the counter will be updated on a + single thread. Conversely, if we want the value of the output (which we + print on the line `print(sess.run(out).shape)`), then the counter may be + incremented on its own thread, while `x` can be incremented in + parallel on a separate thread. In the extreme case, it is conceivable + that the thread incrementing the counter runs until completion before + `x` is incremented even a single time. The only thing that can never + happen is that the thread updating `x` can never get ahead of the + counter thread because the thread incrementing `x` depends on the value + of the counter. + ```python + import tensorflow as tf + + n = 10000 + x = tf.constant(list(range(n))) + c = lambda i, x: i < n + b = lambda i, x: (tf.Print(i + 1, [i]), tf.Print(x + 1, [i], "x:")) + i, out = tf.while_loop(c, b, (0, x)) + with tf.Session() as sess: + print(sess.run(i)) # prints [0] ... [9999] + + # The following line may increment the counter and x in parallel. + # The counter thread may get ahead of the other thread, but not the + # other way around. So you may see things like + # [9996] x:[9987] + # meaning that the counter thread is on iteration 9996, + # while the other thread is on iteration 9987 + print(sess.run(out).shape) + ``` + """ with ops.name_scope(name, "while", loop_vars): if not loop_vars: @@ -3173,10 +3202,7 @@ def while_loop(cond, swap_memory=swap_memory) # Only add non-nested loops to the collection. Any nested control flow will # be encapsulated in the root context. - # TODO(b/72868227): enable condition once the corresponding - # control_flow.proto changes have been checked in (they aren't checked in - # and this is disabled for now to ensure forwards compatibility). - if True or loop_context.outer_context is None: + if loop_context.outer_context is None: ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants) if maximum_iterations is not None: @@ -3378,7 +3404,12 @@ def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined if context.in_eager_mode(): return tensors with ops.name_scope(name, "tuple", tensors) as name: - gating_ops = [t.op for t in tensors if t is not None] + tensors = [t if (isinstance(t, ops.Operation) + or tensor_util.is_tensor(t) + or t is None) + else ops.convert_to_tensor(t) for t in tensors] + gating_ops = [t if isinstance(t, ops.Operation) else t.op for t in tensors + if t is not None] if control_inputs: for c in control_inputs: if isinstance(c, ops.Tensor): @@ -3394,8 +3425,11 @@ def tuple(tensors, name=None, control_inputs=None): # pylint: disable=redefined gate = group(*gating_ops) tpl = [] for t in tensors: - if t is not None: + if tensor_util.is_tensor(t): tpl.append(with_dependencies([gate], t)) + elif isinstance(t, ops.Operation): + with ops.control_dependencies([gate]): + tpl.append(group(t)) else: tpl.append(None) return tpl diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py index 95e45bff066d4b2653e5de67684a6277006345f2..03ed537cfcf27151a0200d7a17f63b1a2bc7ba1a 100644 --- a/tensorflow/python/ops/data_flow_ops.py +++ b/tensorflow/python/ops/data_flow_ops.py @@ -474,7 +474,7 @@ class QueueBase(object): name: A name for the operation (optional). Returns: - The tuple of concatenated tensors that was dequeued. + The list of concatenated tensors that was dequeued. """ if name is None: name = "%s_DequeueMany" % self._name diff --git a/tensorflow/python/ops/distributions/bernoulli.py b/tensorflow/python/ops/distributions/bernoulli.py index 553e5db8d81f7b687b826368f2663f874441bdb9..68aaf3815e7e2b21c9550562aa49195569c8ea43 100644 --- a/tensorflow/python/ops/distributions/bernoulli.py +++ b/tensorflow/python/ops/distributions/bernoulli.py @@ -22,7 +22,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import random_ops @@ -137,21 +136,12 @@ class Bernoulli(distribution.Distribution): return (array_ops.ones_like(event) * logits, array_ops.ones_like(logits) * event) - # First check static shape. - if (event.get_shape().is_fully_defined() and - logits.get_shape().is_fully_defined()): - if event.get_shape() != logits.get_shape(): - logits, event = _broadcast(logits, event) - else: - logits, event = control_flow_ops.cond( - distribution_util.same_dynamic_shape(logits, event), - lambda: (logits, event), - lambda: _broadcast(logits, event)) + if not (event.get_shape().is_fully_defined() and + logits.get_shape().is_fully_defined() and + event.get_shape() == logits.get_shape()): + logits, event = _broadcast(logits, event) return -nn.sigmoid_cross_entropy_with_logits(labels=event, logits=logits) - def _prob(self, event): - return math_ops.exp(self._log_prob(event)) - def _entropy(self): return (-self.logits * (math_ops.sigmoid(self.logits) - 1) + nn.softplus(-self.logits)) diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py index be4ef550dddc4f393f3d81730be59fc0def47500..469bcadb8ea3a0ec2a85d3a72c0ca5ba08796856 100644 --- a/tensorflow/python/ops/distributions/beta.py +++ b/tensorflow/python/ops/distributions/beta.py @@ -304,11 +304,10 @@ class Beta(distribution.Distribution): if not self.validate_args: return x return control_flow_ops.with_dependencies([ - check_ops.assert_positive( - x, - message="sample must be positive"), + check_ops.assert_positive(x, message="sample must be positive"), check_ops.assert_less( - x, array_ops.ones([], self.dtype), + x, + array_ops.ones([], self.dtype), message="sample must be less than `1`."), ], x) diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py index 26b5c5aef98fc11b07a8c8357e7ec37819587da9..4ae67a009b0a4052f6e23e2e42262bb7c42f1c14 100644 --- a/tensorflow/python/ops/distributions/multinomial.py +++ b/tensorflow/python/ops/distributions/multinomial.py @@ -238,7 +238,7 @@ class Multinomial(distribution.Distribution): n_draws = math_ops.cast(self.total_count, dtype=dtypes.int32) k = self.event_shape_tensor()[0] - # boardcast the total_count and logits to same shape + # broadcast the total_count and logits to same shape n_draws = array_ops.ones_like( self.logits[..., 0], dtype=n_draws.dtype) * n_draws logits = array_ops.ones_like( diff --git a/tensorflow/python/ops/distributions/special_math.py b/tensorflow/python/ops/distributions/special_math.py index bed4cbb2c1a43b6952861f4fab82957229e23c9c..1d605c5dfcca9b709a9178ccbe56619f6a92f869 100644 --- a/tensorflow/python/ops/distributions/special_math.py +++ b/tensorflow/python/ops/distributions/special_math.py @@ -213,7 +213,7 @@ def _ndtri(p): # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z), # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different - # arrays based on wether p < exp(-32). + # arrays based on whether p < exp(-32). z = math_ops.sqrt(-2. * math_ops.log(sanitized_mcp)) first_term = z - math_ops.log(z) / z second_term_small_p = (_create_polynomial(1. / z, p2) diff --git a/tensorflow/python/ops/distributions/util.py b/tensorflow/python/ops/distributions/util.py index 0a3000ef5ca0decf8aba641e704406b0cf8780af..0fe6aa30f945dc7682a53fa6495823288cf111b7 100644 --- a/tensorflow/python/ops/distributions/util.py +++ b/tensorflow/python/ops/distributions/util.py @@ -1060,9 +1060,7 @@ def reduce_weighted_logsumexp( wx_over_max_absw_x = ( math_ops.sign(w) * math_ops.exp(log_absw_x - max_log_absw_x)) sum_wx_over_max_absw_x = math_ops.reduce_sum( - wx_over_max_absw_x, - axis=axis, - keepdims=keep_dims) + wx_over_max_absw_x, axis=axis, keepdims=keep_dims) if not keep_dims: max_log_absw_x = array_ops.squeeze(max_log_absw_x, axis) sgn = math_ops.sign(sum_wx_over_max_absw_x) @@ -1180,8 +1178,7 @@ def process_quadrature_grid_and_probs( grid = ops.convert_to_tensor(grid, name="grid", dtype=dtype) probs = ops.convert_to_tensor(probs, name="unnormalized_probs", dtype=dtype) - probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, - name="probs") + probs /= linalg_ops.norm(probs, ord=1, axis=-1, keepdims=True, name="probs") def _static_event_size(x): """Returns the static size of a specific dimension or `None`.""" diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py index 9f06c0ee1f403708a0480509cbede579fa6811ee..1418c0b10fb60601e7c3024891b89aadb53e6873 100644 --- a/tensorflow/python/ops/gradients_impl.py +++ b/tensorflow/python/ops/gradients_impl.py @@ -494,7 +494,7 @@ def gradients(ys, list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope: ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y") xs = [ - x.handle if isinstance(x, resource_variable_ops.ResourceVariable) else x + x.handle if resource_variable_ops.is_resource_variable(x) else x for x in xs ] xs = ops.internal_convert_n_to_tensor_or_indexed_slices( diff --git a/tensorflow/python/ops/hidden_ops.txt b/tensorflow/python/ops/hidden_ops.txt index f6ef6f3f3da4389a16a84fa0b3570d3cd1262472..9b8172bf2639cca0efb663ff4075b36d6f4f2245 100644 --- a/tensorflow/python/ops/hidden_ops.txt +++ b/tensorflow/python/ops/hidden_ops.txt @@ -32,6 +32,8 @@ TileGrad # Exported through array_grad instead of array_ops. ZerosLike # TODO(josh11b): Use this instead of the Python version. Unique UniqueV2 +UniqueWithCounts +UniqueWithCountsV2 Unpack # candidate_sampling_ops diff --git a/tensorflow/python/ops/image_grad.py b/tensorflow/python/ops/image_grad.py index d17f1a87d9759d5e83393f40e9e027dee8c15979..093843cd5bc0b7c2281a0c9ddf52d93ea3faede3 100644 --- a/tensorflow/python/ops/image_grad.py +++ b/tensorflow/python/ops/image_grad.py @@ -61,15 +61,10 @@ def _ResizeBilinearGrad(op, grad): Returns: The gradients w.r.t. the input. """ - allowed_types = [dtypes.float32, dtypes.float64] - grad0 = None - if op.inputs[0].dtype in allowed_types: - # pylint: disable=protected-access - grad0 = gen_image_ops._resize_bilinear_grad( - grad, - op.inputs[0], - align_corners=op.get_attr("align_corners")) - # pylint: enable=protected-access + # pylint: disable=protected-access + grad0 = gen_image_ops._resize_bilinear_grad( + grad, op.inputs[0], align_corners=op.get_attr("align_corners")) + # pylint: enable=protected-access return [grad0, None] diff --git a/tensorflow/python/ops/image_grad_test.py b/tensorflow/python/ops/image_grad_test.py index 05e8fa1d72851caee522bba470bb40f430152464..75d00c8ed17c26c2c1acb4d92961a2206d959ebb 100644 --- a/tensorflow/python/ops/image_grad_test.py +++ b/tensorflow/python/ops/image_grad_test.py @@ -142,18 +142,6 @@ class ResizeBilinearOpTest(test.TestCase): input_tensor, in_shape, resize_out, out_shape, x_init_value=x) self.assertLess(err, 1e-3) - def testGradOnUnsupportedType(self): - in_shape = [1, 4, 6, 1] - out_shape = [1, 2, 3, 1] - - x = np.arange(0, 24).reshape(in_shape).astype(np.uint8) - - with self.test_session(): - input_tensor = constant_op.constant(x, shape=in_shape) - resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3]) - grad = gradients_impl.gradients(input_tensor, [resize_out]) - self.assertEqual([None], grad) - def testCompareGpuVsCpu(self): in_shape = [2, 4, 6, 3] out_shape = [2, 8, 16, 3] @@ -172,6 +160,26 @@ class ResizeBilinearOpTest(test.TestCase): self.assertAllClose(grad[False], grad[True], rtol=1e-4, atol=1e-4) + def testTypes(self): + in_shape = [1, 4, 6, 1] + out_shape = [1, 2, 3, 1] + x = np.arange(0, 24).reshape(in_shape) + + with self.test_session() as sess: + for dtype in [np.float16, np.float32, np.float64]: + input_tensor = constant_op.constant(x.astype(dtype), shape=in_shape) + resize_out = image_ops.resize_bilinear(input_tensor, out_shape[1:3]) + grad = sess.run(gradients_impl.gradients(resize_out, input_tensor))[0] + self.assertAllEqual(in_shape, grad.shape) + # Not using gradient_checker.compute_gradient as I didn't work out + # the changes required to compensate for the lower precision of + # float16 when computing the numeric jacobian. + # Instead, we just test the theoretical jacobian. + self.assertAllEqual([[[[1.], [0.], [1.], [0.], [1.], [0.]], [[0.], [ + 0. + ], [0.], [0.], [0.], [0.]], [[1.], [0.], [1.], [0.], [1.], [0.]], + [[0.], [0.], [0.], [0.], [0.], [0.]]]], grad) + class ResizeBicubicOpTest(test.TestCase): diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py index de12c5f63f4357e0982dd2e16999caf2de0b30f8..ae52d32fea1c872e588c4122f5e73198e4dfe9ad 100644 --- a/tensorflow/python/ops/image_ops.py +++ b/tensorflow/python/ops/image_ops.py @@ -26,6 +26,7 @@ See the @{$python/image} guide. @@extract_jpeg_shape @@decode_png @@encode_png +@@is_jpeg @@decode_image @@resize_images @@resize_area diff --git a/tensorflow/python/ops/image_ops_impl.py b/tensorflow/python/ops/image_ops_impl.py index 14a38f25d1028577b553623ebfddfb5c683ad093..58c18c6696d64ccca4ebfaa07242d3c7789116e4 100644 --- a/tensorflow/python/ops/image_ops_impl.py +++ b/tensorflow/python/ops/image_ops_impl.py @@ -167,6 +167,28 @@ def _Assert3DImage(image): _Check3DImage(image, require_static=False), image) +def _AssertAtLeast3DImage(image): + """Assert that we are working with a properly shaped image. + + Performs the check statically if possible (i.e. if the shape + is statically known). Otherwise adds a control dependency + to an assert op that checks the dynamic shape. + + Args: + image: >= 3-D Tensor of size [*, height, width, depth] + + Raises: + ValueError: if image.shape is not a [>= 3] vector. + + Returns: + If the shape of `image` could be verified statically, `image` is + returned unchanged, otherwise there will be a control dependency + added that asserts the correct dynamic shape. + """ + return control_flow_ops.with_dependencies( + _CheckAtLeast3DImage(image, require_static=False), image) + + def _CheckAtLeast3DImage(image, require_static=True): """Assert that we are working with properly shaped image. @@ -292,108 +314,185 @@ def random_flip_left_right(image, seed=None): def flip_left_right(image): """Flip an image horizontally (left to right). - Outputs the contents of `image` flipped along the second dimension, which is - `width`. + Outputs the contents of `image` flipped along the width dimension. See also `reverse()`. Args: - image: A 3-D tensor of shape `[height, width, channels].` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. Returns: - A 3-D tensor of the same type and shape as `image`. + A tensor of the same type and shape as `image`. Raises: ValueError: if the shape of `image` not supported. """ - with ops.name_scope(None, 'flip_left_right', [image]) as scope: + with ops.name_scope(None, 'flip_left_right', [image]): image = ops.convert_to_tensor(image, name='image') - image = _Assert3DImage(image) - return fix_image_flip_shape(image, array_ops.reverse( - image, [1], name=scope)) + image = _AssertAtLeast3DImage(image) + shape = image.get_shape() + if shape.ndims == 3 or shape.ndims is None: + return fix_image_flip_shape(image, array_ops.reverse(image, [1])) + elif shape.ndims == 4: + return array_ops.reverse(image, [2]) + else: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') @tf_export('image.flip_up_down') def flip_up_down(image): """Flip an image vertically (upside down). - Outputs the contents of `image` flipped along the first dimension, which is - `height`. + Outputs the contents of `image` flipped along the height dimension. See also `reverse()`. Args: - image: A 3-D tensor of shape `[height, width, channels].` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. Returns: - A 3-D tensor of the same type and shape as `image`. + A tensor of the same type and shape as `image`. Raises: ValueError: if the shape of `image` not supported. """ - with ops.name_scope(None, 'flip_up_down', [image]) as scope: + with ops.name_scope(None, 'flip_up_down', [image]): image = ops.convert_to_tensor(image, name='image') - image = _Assert3DImage(image) - return fix_image_flip_shape(image, array_ops.reverse( - image, [0], name=scope)) + image = _AssertAtLeast3DImage(image) + shape = image.get_shape() + if shape.ndims == 3 or shape.ndims is None: + return fix_image_flip_shape(image, array_ops.reverse(image, [0])) + elif shape.ndims == 4: + return array_ops.reverse(image, [1]) + else: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') @tf_export('image.rot90') def rot90(image, k=1, name=None): - """Rotate an image counter-clockwise by 90 degrees. + """Rotate image(s) counter-clockwise by 90 degrees. Args: - image: A 3-D tensor of shape `[height, width, channels]`. + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. k: A scalar integer. The number of times the image is rotated by 90 degrees. name: A name for this operation (optional). Returns: - A rotated 3-D tensor of the same type and shape as `image`. + A rotated tensor of the same type and shape as `image`. + + Raises: + ValueError: if the shape of `image` not supported. """ with ops.name_scope(name, 'rot90', [image, k]) as scope: image = ops.convert_to_tensor(image, name='image') - image = _Assert3DImage(image) + image = _AssertAtLeast3DImage(image) k = ops.convert_to_tensor(k, dtype=dtypes.int32, name='k') k.get_shape().assert_has_rank(0) k = math_ops.mod(k, 4) - def _rot90(): - return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2]) + shape = image.get_shape() + if shape.ndims == 3 or shape.ndims is None: + return _rot90_3D(image, k, scope) + elif shape.ndims == 4: + return _rot90_4D(image, k, scope) + else: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') - def _rot180(): - return array_ops.reverse_v2(image, [0, 1]) - def _rot270(): - return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1]) +def _rot90_3D(image, k, name_scope): + """Rotate image counter-clockwise by 90 degrees `k` times. - cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180), - (math_ops.equal(k, 3), _rot270)] + Args: + image: 3-D Tensor of shape `[height, width, channels]`. + k: A scalar integer. The number of times the image is rotated by 90 degrees. + name_scope: A valid TensorFlow name scope. + + Returns: + A 3-D tensor of the same type and shape as `image`. + + """ + + def _rot90(): + return array_ops.transpose(array_ops.reverse_v2(image, [1]), [1, 0, 2]) + + def _rot180(): + return array_ops.reverse_v2(image, [0, 1]) + + def _rot270(): + return array_ops.reverse_v2(array_ops.transpose(image, [1, 0, 2]), [1]) + + cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180), + (math_ops.equal(k, 3), _rot270)] + + result = control_flow_ops.case( + cases, default=lambda: image, exclusive=True, name=name_scope) + result.set_shape([None, None, image.get_shape()[2]]) + return result + + +def _rot90_4D(images, k, name_scope): + """Rotate batch of images counter-clockwise by 90 degrees `k` times. + + Args: + images: 4-D Tensor of shape `[height, width, channels]`. + k: A scalar integer. The number of times the images are rotated by 90 + degrees. + name_scope: A valid TensorFlow name scope. + + Returns: + A 4-D tensor of the same type and shape as `images`. + + """ - ret = control_flow_ops.case( - cases, default=lambda: image, exclusive=True, name=scope) - ret.set_shape([None, None, image.get_shape()[2]]) - return ret + def _rot90(): + return array_ops.transpose(array_ops.reverse_v2(images, [2]), [0, 2, 1, 3]) + def _rot180(): + return array_ops.reverse_v2(images, [1, 2]) + def _rot270(): + return array_ops.reverse_v2(array_ops.transpose(images, [0, 2, 1, 3]), [2]) + + cases = [(math_ops.equal(k, 1), _rot90), (math_ops.equal(k, 2), _rot180), + (math_ops.equal(k, 3), _rot270)] + + result = control_flow_ops.case( + cases, default=lambda: images, exclusive=True, name=name_scope) + shape = result.get_shape() + result.set_shape([shape[0], None, None, shape[3]]) + return result @tf_export('image.transpose_image') def transpose_image(image): - """Transpose an image by swapping the first and second dimension. + """Transpose image(s) by swapping the height and width dimension. See also `transpose()`. Args: - image: 3-D tensor of shape `[height, width, channels]` + image: 4-D Tensor of shape `[batch, height, width, channels]` or + 3-D Tensor of shape `[height, width, channels]`. Returns: - A 3-D tensor of shape `[width, height, channels]` + If `image` was 4-D, a 4-D float Tensor of shape + `[batch, width, height, channels]` + If `image` was 3-D, a 3-D float Tensor of shape + `[width, height, channels]` Raises: ValueError: if the shape of `image` not supported. """ - with ops.name_scope(None, 'transpose_image', [image]) as scope: + with ops.name_scope(None, 'transpose_image', [image]): image = ops.convert_to_tensor(image, name='image') - image = _Assert3DImage(image) - return array_ops.transpose(image, [1, 0, 2], name=scope) + image = _AssertAtLeast3DImage(image) + shape = image.get_shape() + if shape.ndims == 3 or shape.ndims is None: + return array_ops.transpose(image, [1, 0, 2], name='transpose_image') + elif shape.ndims == 4: + return array_ops.transpose(image, [0, 2, 1, 3], name='transpose_image') + else: + raise ValueError('\'image\' must have either 3 or 4 dimensions.') @tf_export('image.central_crop') @@ -1026,9 +1125,9 @@ def adjust_contrast(images, contrast_factor): def adjust_gamma(image, gamma=1, gain=1): """Performs Gamma Correction on the input image. - Also known as Power Law Transform. This function transforms the - input image pixelwise according to the equation Out = In**gamma - after scaling each pixel to the range 0 to 1. + Also known as Power Law Transform. This function transforms the + input image pixelwise according to the equation `Out = In**gamma` + after scaling each pixel to the range 0 to 1. Args: image : A Tensor. @@ -1339,6 +1438,26 @@ def adjust_saturation(image, saturation_factor, name=None): orig_dtype) +@tf_export('image.is_jpeg') +def is_jpeg(contents, name=None): + r"""Convenience function to check if the 'contents' encodes a JPEG image. + + Args: + contents: 0-D `string`. The encoded image bytes. + name: A name for the operation (optional) + + Returns: + A scalar boolean tensor indicating if 'contents' may be a JPEG image. + is_jpeg is susceptible to false positives. + """ + # Normal JPEGs start with \xff\xd8\xff\xe0 + # JPEG with EXIF stats with \xff\xd8\xff\xe1 + # Use \xff\xd8\xff to cover both. + with ops.name_scope(name, 'is_jpeg'): + substr = string_ops.substr(contents, 0, 3) + return math_ops.equal(substr, b'\xff\xd8\xff', name=name) + + @tf_export('image.decode_image') def decode_image(contents, channels=None, name=None): """Convenience function for `decode_bmp`, `decode_gif`, `decode_jpeg`, @@ -1427,8 +1546,8 @@ def decode_image(contents, channels=None, name=None): # Decode normal JPEG images (start with \xff\xd8\xff\xe0) # as well as JPEG images with EXIF data (start with \xff\xd8\xff\xe1). - is_jpeg = math_ops.equal(substr, b'\xff\xd8\xff', name='is_jpeg') - return control_flow_ops.cond(is_jpeg, _jpeg, check_png, name='cond_jpeg') + return control_flow_ops.cond( + is_jpeg(contents), _jpeg, check_png, name='cond_jpeg') @tf_export('image.total_variation') diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index 91a74376520479594ad5cb7897730717f252d228..b8c4b27c162acdd86d88da641ff8afffaa5a9e6a 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -934,7 +934,7 @@ class AdjustSaturationTest(test_util.TensorFlowTestCase): class FlipTransposeRotateTest(test_util.TensorFlowTestCase): - def testIdempotentLeftRight(self): + def testInvolutionLeftRight(self): x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1]) with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) @@ -942,6 +942,16 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_tf = y.eval() self.assertAllEqual(y_tf, x_np) + def testInvolutionLeftRightWithBatch(self): + x_np = np.array( + [[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.flip_left_right(image_ops.flip_left_right(x_tf)) + y_tf = y.eval() + self.assertAllEqual(y_tf, x_np) + def testLeftRight(self): x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1]) y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1]) @@ -953,9 +963,24 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_tf = y.eval() self.assertAllEqual(y_tf, y_np) + def testLeftRightWithBatch(self): + x_np = np.array( + [[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + y_np = np.array( + [[[3, 2, 1], [3, 2, 1]], [[3, 2, 1], [3, 2, 1]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.flip_left_right(x_tf) + y_tf = y.eval() + self.assertAllEqual(y_tf, y_np) + def testRandomFlipLeftRight(self): x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1]) y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1]) + seed = 42 with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) @@ -964,7 +989,7 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): count_flipped = 0 count_unflipped = 0 - for _ in range(50): + for _ in range(100): y_tf = y.eval() if y_tf[0][0] == 1: self.assertAllEqual(y_tf, x_np) @@ -972,10 +997,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): else: self.assertAllEqual(y_tf, y_np) count_flipped += 1 - self.assertGreaterEqual(count_flipped, 1) - self.assertGreaterEqual(count_unflipped, 1) - def testIdempotentUpDown(self): + # 100 trials + # Mean: 50 + # Std Dev: ~5 + # Six Sigma: 50 - (5 * 6) = 20 + self.assertGreaterEqual(count_flipped, 20) + self.assertGreaterEqual(count_unflipped, 20) + + def testInvolutionUpDown(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) with self.test_session(use_gpu=True): @@ -984,6 +1014,17 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_tf = y.eval() self.assertAllEqual(y_tf, x_np) + def testInvolutionUpDownWithBatch(self): + x_np = np.array( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.flip_up_down(image_ops.flip_up_down(x_tf)) + y_tf = y.eval() + self.assertAllEqual(y_tf, x_np) + def testUpDown(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1]) @@ -995,17 +1036,31 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_tf = y.eval() self.assertAllEqual(y_tf, y_np) + def testUpDownWithBatch(self): + x_np = np.array( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + y_np = np.array( + [[[4, 5, 6], [1, 2, 3]], [[10, 11, 12], [7, 8, 9]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.flip_up_down(x_tf) + y_tf = y.eval() + self.assertAllEqual(y_tf, y_np) + def testRandomFlipUpDown(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1]) with self.test_session(use_gpu=True): x_tf = constant_op.constant(x_np, shape=x_np.shape) - y = image_ops.random_flip_up_down(x_tf) + y = image_ops.random_flip_up_down(x_tf, seed=42) self.assertTrue(y.op.name.startswith("random_flip_up_down")) count_flipped = 0 count_unflipped = 0 - for _ in range(50): + for _ in range(100): y_tf = y.eval() if y_tf[0][0] == 1: self.assertAllEqual(y_tf, x_np) @@ -1013,10 +1068,15 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): else: self.assertAllEqual(y_tf, y_np) count_flipped += 1 - self.assertGreaterEqual(count_flipped, 1) - self.assertGreaterEqual(count_unflipped, 1) - def testIdempotentTranspose(self): + # 100 trials + # Mean: 50 + # Std Dev: ~5 + # Six Sigma: 50 - (5 * 6) = 20 + self.assertGreaterEqual(count_flipped, 20) + self.assertGreaterEqual(count_unflipped, 20) + + def testInvolutionTranspose(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) with self.test_session(use_gpu=True): @@ -1025,6 +1085,17 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_tf = y.eval() self.assertAllEqual(y_tf, x_np) + def testInvolutionTransposeWithBatch(self): + x_np = np.array( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.transpose_image(image_ops.transpose_image(x_tf)) + y_tf = y.eval() + self.assertAllEqual(y_tf, x_np) + def testTranspose(self): x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1]) y_np = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.uint8).reshape([3, 2, 1]) @@ -1036,15 +1107,34 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_tf = y.eval() self.assertAllEqual(y_tf, y_np) + def testTransposeWithBatch(self): + x_np = np.array( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], + dtype=np.uint8).reshape([2, 2, 3, 1]) + + y_np = np.array( + [[[1, 4], [2, 5], [3, 6]], [[7, 10], [8, 11], [9, 12]]], + dtype=np.uint8).reshape([2, 3, 2, 1]) + + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np, shape=x_np.shape) + y = image_ops.transpose_image(x_tf) + y_tf = y.eval() + self.assertAllEqual(y_tf, y_np) + def testPartialShapes(self): p_unknown_rank = array_ops.placeholder(dtypes.uint8) - p_unknown_dims = array_ops.placeholder( + p_unknown_dims_3 = array_ops.placeholder( dtypes.uint8, shape=[None, None, None]) + p_unknown_dims_4 = array_ops.placeholder( + dtypes.uint8, shape=[None, None, None, None]) p_unknown_width = array_ops.placeholder(dtypes.uint8, shape=[64, None, 3]) - + p_unknown_batch = array_ops.placeholder( + dtypes.uint8, shape=[None, 64, 64, 3]) p_wrong_rank = array_ops.placeholder(dtypes.uint8, shape=[None, None]) p_zero_dim = array_ops.placeholder(dtypes.uint8, shape=[64, 0, 3]) + #Ops that support 3D input for op in [ image_ops.flip_left_right, image_ops.flip_up_down, image_ops.random_flip_left_right, image_ops.random_flip_up_down, @@ -1052,16 +1142,35 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): ]: transformed_unknown_rank = op(p_unknown_rank) self.assertEqual(3, transformed_unknown_rank.get_shape().ndims) - transformed_unknown_dims = op(p_unknown_dims) - self.assertEqual(3, transformed_unknown_dims.get_shape().ndims) + transformed_unknown_dims_3 = op(p_unknown_dims_3) + self.assertEqual(3, transformed_unknown_dims_3.get_shape().ndims) transformed_unknown_width = op(p_unknown_width) self.assertEqual(3, transformed_unknown_width.get_shape().ndims) - with self.assertRaisesRegexp(ValueError, "must be three-dimensional"): - op(p_wrong_rank) with self.assertRaisesRegexp(ValueError, "must be > 0"): op(p_zero_dim) + #Ops that support 4D input + for op in [ + image_ops.flip_left_right, image_ops.flip_up_down, + image_ops.transpose_image, image_ops.rot90 + ]: + transformed_unknown_dims_4 = op(p_unknown_dims_4) + self.assertEqual(4, transformed_unknown_dims_4.get_shape().ndims) + transformed_unknown_batch = op(p_unknown_batch) + self.assertEqual(4, transformed_unknown_batch.get_shape().ndims) + with self.assertRaisesRegexp(ValueError, + "must be at least three-dimensional"): + op(p_wrong_rank) + + for op in [ + image_ops.random_flip_left_right, + image_ops.random_flip_up_down, + ]: + with self.assertRaisesRegexp(ValueError, "must be three-dimensional"): + op(p_wrong_rank) + + def testRot90GroupOrder(self): image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3]) with self.test_session(use_gpu=True): @@ -1070,6 +1179,14 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): rotated = image_ops.rot90(rotated) self.assertAllEqual(image, rotated.eval()) + def testRot90GroupOrderWithBatch(self): + image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3]) + with self.test_session(use_gpu=True): + rotated = image + for _ in xrange(4): + rotated = image_ops.rot90(rotated) + self.assertAllEqual(image, rotated.eval()) + def testRot90NumpyEquivalence(self): image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3]) with self.test_session(use_gpu=True): @@ -1079,6 +1196,14 @@ class FlipTransposeRotateTest(test_util.TensorFlowTestCase): y_np = np.rot90(image, k=k) self.assertAllEqual(y_np, y_tf.eval({k_placeholder: k})) + def testRot90NumpyEquivalenceWithBatch(self): + image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3]) + with self.test_session(use_gpu=True): + k_placeholder = array_ops.placeholder(dtypes.int32, shape=[]) + y_tf = image_ops.rot90(image, k_placeholder) + for k in xrange(4): + y_np = np.rot90(image, k=k, axes=(1, 2)) + self.assertAllEqual(y_np, y_tf.eval({k_placeholder: k})) class RandomFlipTest(test_util.TensorFlowTestCase): @@ -1909,7 +2034,8 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): bounding_box = constant_op.constant( [[[0.0, 0.0, 1.0, 1.0]]], shape=[1, 1, 4], - dtype=dtypes.float32,) + dtype=dtypes.float32, + ) begin, end, bbox_for_drawing = image_ops.sample_distorted_bounding_box( image_size=image_size, bounding_boxes=bounding_box, @@ -1924,6 +2050,7 @@ class SelectDistortedCropBoxTest(test_util.TensorFlowTestCase): end = end.eval() bbox_for_drawing = bbox_for_drawing.eval() + class ResizeImagesTest(test_util.TensorFlowTestCase): OPTIONS = [ @@ -3173,12 +3300,11 @@ class NonMaxSuppressionTest(test_util.TensorFlowTestCase): # The boxes is of shape [num_boxes, 4], and the scores is # of shape [num_boxes]. So an error will thrown. - with self.assertRaisesRegexp( - ValueError, 'Dimensions must be equal, but are 1 and 2'): + with self.assertRaisesRegexp(ValueError, + "Dimensions must be equal, but are 1 and 2"): boxes = constant_op.constant([[0.0, 0.0, 1.0, 1.0]]) scores = constant_op.constant([0.9, 0.75]) - selected_indices = image_ops.non_max_suppression( - boxes, scores, 3, 0.5) + selected_indices = image_ops.non_max_suppression(boxes, scores, 3, 0.5) # The scores should be 1D of shape [num_boxes]. with self.assertRaisesRegexp(ValueError, diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py index b3ec3d5b7cf45ac0b2672eea9a4586b2c3295897..e180e830263c44fb5ae290d307f1ef80106c31d5 100644 --- a/tensorflow/python/ops/linalg/linear_operator_diag.py +++ b/tensorflow/python/ops/linalg/linear_operator_diag.py @@ -67,7 +67,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): operator = LinearOperatorDiag(diag) # Create a shape [2, 1, 4, 2] vector. Note that this shape is compatible - # since the batch dimensions, [2, 1], are brodcast to + # since the batch dimensions, [2, 1], are broadcast to # operator.batch_shape = [2, 3]. y = tf.random_normal(shape=[2, 1, 4, 2]) x = operator.solve(y) diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py index eadbc1b7c3b6e66aa76c9afd860b2274ac1976ae..3757109c956dfedc64ac4cda4ad13a4cfa601418 100644 --- a/tensorflow/python/ops/logging_ops.py +++ b/tensorflow/python/ops/logging_ops.py @@ -356,3 +356,4 @@ ops.NotDifferentiable("AudioSummary") ops.NotDifferentiable("AudioSummaryV2") ops.NotDifferentiable("MergeSummary") ops.NotDifferentiable("ScalarSummary") +ops.NotDifferentiable("Timestamp") diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py index 8e003fb7ac6462fb611a020e86b06b5987af9546..7386976e93fbb82f38550f50429af878fadda813 100644 --- a/tensorflow/python/ops/losses/losses_impl.py +++ b/tensorflow/python/ops/losses/losses_impl.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import confusion_matrix @@ -88,6 +89,14 @@ def _safe_div(numerator, denominator, name="value"): Returns: The element-wise value of the numerator divided by the denominator. """ + if isinstance(denominator, float): + if math_ops.equal(denominator, 0.0): + return ops.convert_to_tensor(0.0, dtype=numerator.dtype) + return math_ops.div(numerator, denominator) + if context.in_eager_mode() and denominator._rank() == 0: # pylint: disable=protected-access + if math_ops.equal(denominator, 0.0): + return ops.convert_to_tensor(0.0, dtype=numerator.dtype) + return math_ops.div(numerator, denominator) return array_ops.where( math_ops.greater(denominator, 0), math_ops.div(numerator, array_ops.where( @@ -134,6 +143,10 @@ def _num_present(losses, weights, per_batch=False): `per_batch` is `True`, the value is returned as a tensor of size `[batch_size]`. Otherwise, a single scalar tensor is returned. """ + if ((isinstance(weights, float) and weights != 0.0) or + (context.in_eager_mode() and weights._rank() == 0 # pylint: disable=protected-access + and not math_ops.equal(weights, 0.0))): + return _num_elements(losses) with ops.name_scope(None, "num_present", (losses, weights)) as scope: weights = math_ops.to_float(weights) present = array_ops.where( @@ -143,8 +156,10 @@ def _num_present(losses, weights, per_batch=False): present = weights_broadcast_ops.broadcast_weights(present, losses) if per_batch: return math_ops.reduce_sum( - present, axis=math_ops.range(1, array_ops.rank(present)), - keepdims=True, name=scope) + present, + axis=math_ops.range(1, array_ops.rank(present)), + keepdims=True, + name=scope) return math_ops.reduce_sum(present, name=scope) @@ -421,8 +436,12 @@ def huber_loss(labels, predictions, weights=1.0, delta=1.0, scope=None, # expression when abs_error == delta is 0 (for tf.maximum it would be 1). # This is necessary to avoid doubling the gradient, since there is already a # nonzero contribution to the gradient from the quadratic term. - linear = (abs_error - quadratic) - losses = 0.5 * quadratic * quadratic + delta * linear + linear = math_ops.subtract(abs_error, quadratic) + losses = math_ops.add( + math_ops.multiply( + ops.convert_to_tensor(0.5, dtype=quadratic.dtype), + math_ops.multiply(quadratic, quadratic)), + math_ops.multiply(delta, linear)) return compute_weighted_loss( losses, weights, scope, loss_collection, reduction=reduction) @@ -542,7 +561,8 @@ def mean_pairwise_squared_error( reduction_indices = math_ops.range(1, array_ops.rank(diffs)) sum_squares_diff_per_batch = math_ops.reduce_sum( - math_ops.square(diffs), reduction_indices=reduction_indices, + math_ops.square(diffs), + reduction_indices=reduction_indices, keepdims=True) num_present_per_batch = _num_present(diffs, weights, per_batch=True) @@ -634,7 +654,7 @@ def sigmoid_cross_entropy( Args: multi_class_labels: `[batch_size, num_classes]` target integer labels in - `(0, 1)`. + `{0, 1}`. logits: Float `[batch_size, num_classes]` logits outputs of the network. weights: Optional `Tensor` whose rank is either 0, or the same rank as `labels`, and must be broadcastable to `labels` (i.e., all dimensions must @@ -731,7 +751,6 @@ def softmax_cross_entropy( losses = nn.softmax_cross_entropy_with_logits_v2( labels=onehot_labels, logits=logits, name="xentropy") - return compute_weighted_loss( losses, weights, scope, loss_collection, reduction=reduction) diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index c6cc4e186074e71b8742e4aa5b69a699f77f250e..69afa618e2fae146f75fd70dee4b04d447c843d3 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -35,6 +35,12 @@ def _safe_shape_div(x, y): return x // math_ops.maximum(y, 1) +@ops.RegisterGradient("ArgMax") +def _ArgMaxGrad(op, grad): + del op, grad + return [None, None] + + @ops.RegisterGradient("Sum") def _SumGrad(op, grad): """Gradient for Sum.""" @@ -877,11 +883,13 @@ def _MulGrad(op, grad): sx = array_ops.shape(x) sy = array_ops.shape(y) rx, ry = gen_array_ops._broadcast_gradient_args(sx, sy) - # pylint: enable=protected-access x = math_ops.conj(x) y = math_ops.conj(y) - return (array_ops.reshape(math_ops.reduce_sum(grad * y, rx), sx), - array_ops.reshape(math_ops.reduce_sum(x * grad, ry), sy)) + return (array_ops.reshape( + math_ops.reduce_sum(gen_math_ops._mul(grad, y), rx), sx), + array_ops.reshape( + math_ops.reduce_sum(gen_math_ops._mul(x, grad), ry), sy)) + # pylint: enable=protected-access @ops.RegisterGradient("Div") @@ -1054,18 +1062,20 @@ def _MatMulGrad(op, grad): t_b = op.get_attr("transpose_b") a = math_ops.conj(op.inputs[0]) b = math_ops.conj(op.inputs[1]) + # pylint: disable=protected-access if not t_a and not t_b: - grad_a = math_ops.matmul(grad, b, transpose_b=True) - grad_b = math_ops.matmul(a, grad, transpose_a=True) + grad_a = gen_math_ops._mat_mul(grad, b, transpose_b=True) + grad_b = gen_math_ops._mat_mul(a, grad, transpose_a=True) elif not t_a and t_b: - grad_a = math_ops.matmul(grad, b) - grad_b = math_ops.matmul(grad, a, transpose_a=True) + grad_a = gen_math_ops._mat_mul(grad, b) + grad_b = gen_math_ops._mat_mul(grad, a, transpose_a=True) elif t_a and not t_b: - grad_a = math_ops.matmul(b, grad, transpose_b=True) - grad_b = math_ops.matmul(a, grad) + grad_a = gen_math_ops._mat_mul(b, grad, transpose_b=True) + grad_b = gen_math_ops._mat_mul(a, grad) elif t_a and t_b: - grad_a = math_ops.matmul(b, grad, transpose_a=True, transpose_b=True) - grad_b = math_ops.matmul(grad, a, transpose_a=True, transpose_b=True) + grad_a = gen_math_ops._mat_mul(b, grad, transpose_a=True, transpose_b=True) + grad_b = gen_math_ops._mat_mul(grad, a, transpose_a=True, transpose_b=True) + # pylint: enable=protected-access return grad_a, grad_b diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index da9957aa2a5463a37bba155597600a340ee4f1e6..ed11fe5348d35dc2497ab7e624453b9cc956d376 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -161,14 +161,11 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import gen_sparse_ops from tensorflow.python.ops import gen_spectral_ops -from tensorflow.python.ops import gen_state_ops -from tensorflow.python.ops import state_ops # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.python.ops.gen_math_ops import * @@ -901,6 +898,40 @@ def to_bfloat16(x, name="ToBFloat16"): return cast(x, dtypes.bfloat16, name=name) +@tf_export("to_complex64") +def to_complex64(x, name="ToComplex64"): + """Casts a tensor to type `complex64`. + + Args: + x: A `Tensor` or `SparseTensor`. + name: A name for the operation (optional). + + Returns: + A `Tensor` or `SparseTensor` with same shape as `x` with type `complex64`. + + Raises: + TypeError: If `x` cannot be cast to the `complex64`. + """ + return cast(x, dtypes.complex64, name=name) + + +@tf_export("to_complex128") +def to_complex128(x, name="ToComplex128"): + """Casts a tensor to type `complex128`. + + Args: + x: A `Tensor` or `SparseTensor`. + name: A name for the operation (optional). + + Returns: + A `Tensor` or `SparseTensor` with same shape as `x` with type `complex128`. + + Raises: + TypeError: If `x` cannot be cast to the `complex128`. + """ + return cast(x, dtypes.complex128, name=name) + + ops.Tensor._override_operator("__neg__", gen_math_ops._neg) ops.Tensor._override_operator("__abs__", abs) # __invert__ corresponds to the ~ operator. Here we follow the numpy convention @@ -1295,9 +1326,9 @@ def _ReductionDims(x, axis, reduction_indices): return axis else: # Fast path: avoid creating Rank and Range ops if ndims is known. - if isinstance(x, ops.Tensor) and x.get_shape().ndims is not None: + if isinstance(x, ops.Tensor) and x._rank() is not None: # pylint: disable=protected-access return constant_op.constant( - np.arange(x.get_shape().ndims), dtype=dtypes.int32) + np.arange(x._rank()), dtype=dtypes.int32) # pylint: disable=protected-access if (isinstance(x, sparse_tensor.SparseTensor) and x.dense_shape.get_shape().is_fully_defined()): rank = x.dense_shape.get_shape()[0].value # sparse.dense_shape is 1-D. @@ -2184,14 +2215,12 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): Optionally, pass `shape` and `tensor_dtype` for shape and type checking, otherwise, these are inferred. - NOTE: This operation is not differentiable and cannot be used if inputs depend - on trainable variables. Please use `tf.add_n` for such cases. + `tf.accumulate_n` performs the same operation as `tf.add_n`, but does not + wait for all of its inputs to be ready before beginning to sum. This can + save memory if inputs are ready at different times, since minimum temporary + storage is proportional to the output size rather than the inputs size. - Aside from differentiability, `tf.accumulate_n` performs the same operation as - `tf.add_n`, but does not wait for all of its inputs to be ready before - beginning to sum. This can save memory if inputs are ready at different times, - since minimum temporary storage is proportional to the output size rather than - the inputs size. + `accumulate_n` is differentiable (but wasn't previous to TensorFlow 1.7). For example: @@ -2201,8 +2230,9 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): tf.accumulate_n([a, b, a]) # [[7, 4], [6, 14]] # Explicitly pass shape and type - tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) # [[7, 4], - # [6, 14]] + tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) + # [[7, 4], + # [6, 14]] ``` Args: @@ -2218,20 +2248,17 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): ValueError: If `inputs` don't all have same shape and dtype or the shape cannot be inferred. """ - if context.in_eager_mode(): - # TODO(apassos) remove this once the lifetime of eager variables gets - # addressed. - raise ValueError("accumulate_n not supported in eager mode") + def _input_error(): + return ValueError( + "inputs must be a list of at least one Tensor with the " + "same dtype and shape") if not inputs or not isinstance(inputs, (list, tuple)): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + raise _input_error() inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) if not all(isinstance(x, ops.Tensor) for x in inputs): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + raise _input_error() if not all(x.dtype == inputs[0].dtype for x in inputs): - raise ValueError("inputs must be a list of at least one Tensor with the " - "same dtype and shape") + raise _input_error() if shape is not None: shape = tensor_shape.as_shape(shape) else: @@ -2239,27 +2266,31 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None): for input_tensor in inputs: if isinstance(input_tensor, ops.Tensor): shape = shape.merge_with(input_tensor.get_shape()) - if tensor_dtype is None: - tensor_dtype = inputs[0].dtype - if tensor_dtype != inputs[0].dtype: - raise TypeError("tensor_dtype is {}, but input is of type {}".format( - tensor_dtype, inputs[0].dtype)) - if len(inputs) == 1: + + # tensor_dtype is for safety only; operator's output type computed in C++ + if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: + raise TypeError("tensor_dtype is {}, but input is of type {}" + .format(tensor_dtype, inputs[0].dtype)) + + if len(inputs) == 1 and name is None: return inputs[0] - with ops.name_scope(name, "AccumulateN", inputs) as name: - var = gen_state_ops._temporary_variable( - shape=tensor_shape.vector(0), dtype=tensor_dtype) - with ops.colocate_with(var): - zeros = array_ops.zeros_like(gen_control_flow_ops._merge(inputs)[0]) - zeros.set_shape(shape) - ref = state_ops.assign(var, zeros, validate_shape=False) - update_ops = [ - state_ops.assign_add(ref, input_tensor, use_locking=True) - for input_tensor in inputs - ] - with ops.control_dependencies(update_ops): - return gen_state_ops._destroy_temporary_variable( - ref, var_name=var.op.name, name=name) + elif len(inputs) == 1 and name is not None: + return array_ops.identity(inputs[0], name=name) + elif context.in_eager_mode(): + # TemporaryVariable not currently supported in eager mode; fall back + # onto AddN for now. + # TODO(frreiss) remove this once the lifetime of eager variables gets + # addressed + return add_n(inputs, name=name) + else: + return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) # pylint: disable=protected-access + + +@ops.RegisterGradient("AccumulateNV2") +def _accumulate_n_grad(op, grad): + """Same as gradient for AddN. Copies the gradient to all inputs.""" + # Not broadcasting. + return [grad] * len(op.inputs) @tf_export("nn.sigmoid", "sigmoid") diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index 7776ff08c4f55c43947010f313d8167596b15db7..043c0e30cd8476b1a91e136df60edfbedf85ab24 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -672,7 +672,7 @@ def auc(labels, x = fp_rate y = rec else: # curve == 'PR'. - prec = math_ops.div(tp + epsilon, tp + fp + epsilon) + prec = math_ops.div(tp, tp + fp + epsilon) x = rec y = prec if summation_method == 'trapezoidal': @@ -923,8 +923,8 @@ def mean_per_class_accuracy(labels, weights = array_ops.reshape(weights, [-1]) weights = math_ops.to_float(weights) - is_correct = is_correct * weights - ones = ones * weights + is_correct *= weights + ones *= weights update_total_op = state_ops.scatter_add(total, labels, ones) update_count_op = state_ops.scatter_add(count, labels, is_correct) @@ -1247,13 +1247,8 @@ def mean_tensor(values, with ops.control_dependencies([values]): update_count_op = state_ops.assign_add(count, num_values) - def compute_mean(total, count, name): - non_zero_count = math_ops.maximum( - count, array_ops.ones_like(count), name=name) - return math_ops.truediv(total, non_zero_count, name=name) - - mean_t = compute_mean(total, count, 'value') - update_op = compute_mean(update_total_op, update_count_op, 'update_op') + mean_t = _safe_div(total, count, 'value') + update_op = _safe_div(update_total_op, update_count_op, 'update_op') if metrics_collections: ops.add_to_collections(metrics_collections, mean_t) diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py index 5fa5708114fd5cda6afbca78fa0debf68f0252cc..254f0051a4e878f4c405f5b1c047c9c0cdcef043 100644 --- a/tensorflow/python/ops/nn_impl.py +++ b/tensorflow/python/ops/nn_impl.py @@ -1345,4 +1345,4 @@ def sampled_softmax_loss(weights, sampled_losses = nn_ops.softmax_cross_entropy_with_logits( labels=labels, logits=logits) # sampled_losses is a [batch_size] tensor. - return sampled_losses + return sampled_losses \ No newline at end of file diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 47f48a7e168acd6788954e8e7117993d57c63304..8fbe698914e5f2fa8f821feed82c33fc77e35e21 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -2215,6 +2215,31 @@ def xw_plus_b_v1(x, weights, biases, name=None): # pylint: disable=invalid-name return bias_add_v1(mm, biases, name=name) +def _get_noise_shape(x, noise_shape): + # If noise_shape is none return immediately. + if noise_shape is None: + return array_ops.shape(x) + + try: + # Best effort to figure out the intended shape. + # If not possible, let the op to handle it. + # In eager mode exception will show up. + noise_shape_ = tensor_shape.as_shape(noise_shape) + except (TypeError, ValueError): + return noise_shape + + if x.shape.dims is not None and len(x.shape.dims) == len(noise_shape_.dims): + new_dims = [] + for i, dim in enumerate(x.shape.dims): + if noise_shape_.dims[i].value is None and dim.value is not None: + new_dims.append(dim.value) + else: + new_dims.append(noise_shape_.dims[i].value) + return tensor_shape.TensorShape(new_dims) + + return noise_shape + + @tf_export("nn.dropout") def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: disable=invalid-name """Computes dropout. @@ -2265,7 +2290,8 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di if tensor_util.constant_value(keep_prob) == 1: return x - noise_shape = noise_shape if noise_shape is not None else array_ops.shape(x) + noise_shape = _get_noise_shape(x, noise_shape) + # uniform [keep_prob, 1.0 + keep_prob) random_tensor = keep_prob random_tensor += random_ops.random_uniform( @@ -2380,7 +2406,7 @@ def conv1d(value, Args: value: A 3D `Tensor`. Must be of type `float16` or `float32`. - filters: A 3D `Tensor`. Must have the same type as `input`. + filters: A 3D `Tensor`. Must have the same type as `value`. stride: An `integer`. The number of entries by which the filter is moved right at each step. padding: 'SAME' or 'VALID' diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py index 5a45bdc1e5e1d38a34176ed9443fcd1713f38e1e..21eea3db25af0d1bcfbc7496665f5535c3f660ea 100644 --- a/tensorflow/python/ops/nn_test.py +++ b/tensorflow/python/ops/nn_test.py @@ -383,6 +383,31 @@ class DropoutTest(test_lib.TestCase): x, keep_prob, noise_shape=array_ops.placeholder(dtypes.int32)) self.assertEqual(x.get_shape(), dropout_x.get_shape()) + def testPartialShapedDropout(self): + x_dim = 40 * 30 + y_dim = 3 + num_iter = 10 + for keep_prob in [0.1, 0.5, 0.8]: + with self.test_session(): + t = constant_op.constant( + 1.0, shape=[x_dim, y_dim], dtype=dtypes.float32) + # Set noise_shape=[None, 1] which means [x_dim, 1]. + dropout = nn_ops.dropout(t, keep_prob, noise_shape=[None, 1]) + self.assertEqual([x_dim, y_dim], dropout.get_shape()) + final_count = 0 + for _ in xrange(0, num_iter): + value = dropout.eval() + final_count += np.count_nonzero(value) + # Verifies that there are only two values: 0 and 1/keep_prob. + sorted_value = np.unique(np.sort(value)) + self.assertEqual(0, sorted_value[0]) + self.assertAllClose(1 / keep_prob, sorted_value[1]) + # Check that we are in the 15% error range + expected_count = x_dim * y_dim * keep_prob * num_iter + rel_error = math.fabs(final_count - expected_count) / expected_count + print(rel_error) + self.assertTrue(rel_error < 0.15) + def testInvalidKeepProb(self): x_dim = 40 y_dim = 30 diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 75cb57f16f28ad3c877a62abc894a1299c4fc160..2d6d0672e03d9435175b0accd7c20dfddae16bcc 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import variables # pylint: disable=wildcard-import from tensorflow.python.ops.gen_resource_variable_ops import * # pylint: enable=wildcard-import +from tensorflow.python.training import checkpointable from tensorflow.python.util import compat @@ -107,13 +108,16 @@ class EagerResourceDeleter(object): """ def __init__(self, handle, handle_device): + if not isinstance(handle, ops.Tensor): + raise ValueError( + ("Passed handle=%s to EagerResourceDeleter. Was expecting a handle " + "Tensor." % (handle,))) self._handle = handle self._handle_device = handle_device def __del__(self): # Resources follow object-identity when executing eagerly, so it is safe to - # delete the resource we have a handle to. Each Graph has a unique container - # name, which prevents resource sharing. + # delete the resource we have a handle to. try: # This resource was created in eager mode. However, this destructor may be # running in graph mode (especially during unit tests). To clean up @@ -344,6 +348,11 @@ class ResourceVariable(variables.Variable): if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") + if isinstance(initial_value, checkpointable.CheckpointInitialValue): + self._maybe_initialize_checkpointable() + self._update_uid = initial_value.checkpoint_position.restore_uid + initial_value = initial_value.wrapped_value + self._trainable = trainable if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] @@ -783,38 +792,38 @@ class ResourceVariable(variables.Variable): # TODO(apassos): this here and below is not atomic. Consider making it # atomic if there's a way to do so without a performance cost for those who # don't need it. - with ops.control_dependencies([ - gen_resource_variable_ops.assign_sub_variable_op( - self.handle, - ops.convert_to_tensor(delta, dtype=self.dtype), - name=name) - ]): - return self.read_value() + return self._lazy_read(gen_resource_variable_ops.assign_sub_variable_op( + self.handle, + ops.convert_to_tensor(delta, dtype=self.dtype), + name=name)) def assign_add(self, delta, use_locking=None, name=None): - with ops.control_dependencies([ - gen_resource_variable_ops.assign_add_variable_op( - self.handle, - ops.convert_to_tensor(delta, dtype=self.dtype), - name=name) - ]): - return self.read_value() + return self._lazy_read(gen_resource_variable_ops.assign_add_variable_op( + self.handle, + ops.convert_to_tensor(delta, dtype=self.dtype), + name=name)) + + def _lazy_read(self, op): + if hasattr(self, "_trainable") and self._trainable: + tape.watch_variable(self) + return _UnreadVariable( + self._handle, self.dtype, self._handle_device, self._shape, + self._in_graph_mode, + self._handle_deleter if not self._in_graph_mode else None, op) def assign(self, value, use_locking=None, name=None): value_tensor = ops.convert_to_tensor(value, dtype=self.dtype) self._shape.assert_is_compatible_with(value_tensor.shape) - with ops.control_dependencies([ + return self._lazy_read( gen_resource_variable_ops.assign_variable_op( self.handle, value_tensor, - name=name) - ]): - return self.read_value() + name=name)) def _strided_slice_assign(self, begin, end, strides, value, name, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask): - with ops.control_dependencies([ + return self._lazy_read( gen_array_ops.resource_strided_slice_assign( ref=self.handle, begin=begin, @@ -826,9 +835,12 @@ class ResourceVariable(variables.Variable): end_mask=end_mask, ellipsis_mask=ellipsis_mask, new_axis_mask=new_axis_mask, - shrink_axis_mask=shrink_axis_mask) - ]): - return self.value() + shrink_axis_mask=shrink_axis_mask)) + + def __int__(self): + if self.dtype != dtypes.int32 and self.dtype != dtypes.int64: + raise TypeError("Non-integer variable can't be converted to integer.") + return int(self.value().numpy()) def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False): del name @@ -886,6 +898,61 @@ def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False): return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access +class _UnreadVariable(ResourceVariable): + """Represents a future for a read of a variable. + + Pretends to be the tensor if anyone looks. + """ + + def __init__(self, handle, dtype, handle_device, # pylint: disable=super-init-not-called + shape, in_graph_mode, deleter, parent_op): + # We do not call super init on purpose. + self._trainable = False + self._save_slice_info = None + self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access + self._in_graph_mode = in_graph_mode + self._handle = handle + self._handle_device = handle_device + self._shape = shape + self._initial_value = None + if isinstance(self._handle, ops.EagerTensor): + self._handle_name = "" + else: + self._handle_name = self._handle.name + self._dtype = dtype + self._constraint = None + self._cached_value = None + self._is_initialized_op = None + self._initializer_op = None + self._parent_op = parent_op + if context.in_graph_mode(): + self._graph_element = self.read_value() + else: + self._graph_element = None + self._handle_deleter = deleter + + def value(self): + return self._read_variable_op() + + def read_value(self): + return self._read_variable_op() + + def _read_variable_op(self): + with ops.control_dependencies([self._parent_op]): + return gen_resource_variable_ops.read_variable_op(self._handle, + self._dtype) + + def set_shape(self, shape): + self._shape = shape + + @property + def op(self): + """The op for this variable.""" + return self._parent_op + +ops.register_tensor_conversion_function(_UnreadVariable, _dense_var_to_tensor) +ops.register_dense_tensor_like_type(_UnreadVariable) + # Register a conversion function which reads the value of the variable, # allowing instances of the class to be used as tensors. @@ -957,3 +1024,9 @@ ops.register_proto_function( proto_type=variable_pb2.VariableDef, to_proto=_to_proto_fn, from_proto=_from_proto_fn) + + +def is_resource_variable(var): + """"Returns True if `var` is to be considered a ResourceVariable.""" + return isinstance(var, ResourceVariable) or hasattr( + var, "_should_act_as_resource_variable") diff --git a/tensorflow/python/ops/script_ops.py b/tensorflow/python/ops/script_ops.py index 551b3b0ed47a936949e91826a1124cad464dc9f3..6fe2f61016775b410045fefcc8764907b8ea39f3 100644 --- a/tensorflow/python/ops/script_ops.py +++ b/tensorflow/python/ops/script_ops.py @@ -33,6 +33,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_script_ops +from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -176,7 +177,10 @@ class CleanupFunc(object): self._token = token def __del__(self): - _py_funcs.remove(self._token) + if _py_funcs is not None: + # If _py_funcs is None, the program is most likely in shutdown, and the + # _py_funcs object has been destroyed already. + _py_funcs.remove(self._token) def _internal_py_func(func, inp, Tout, stateful=None, eager=False, name=None): @@ -264,7 +268,7 @@ def py_func(func, inp, Tout, stateful=True, name=None): """Wraps a python function and uses it as a TensorFlow op. Given a python function `func`, which takes numpy arrays as its - inputs and returns numpy arrays as its outputs, wrap this function as an + arguments and returns numpy arrays as its outputs, wrap this function as an operation in a TensorFlow graph. The following snippet constructs a simple TensorFlow graph that invokes the `np.sinh()` NumPy function as a operation in the graph: @@ -273,8 +277,8 @@ def py_func(func, inp, Tout, stateful=True, name=None): def my_func(x): # x will be a numpy array with the contents of the placeholder below return np.sinh(x) - inp = tf.placeholder(tf.float32) - y = tf.py_func(my_func, [inp], tf.float32) + input = tf.placeholder(tf.float32) + y = tf.py_func(my_func, [input], tf.float32) ``` **N.B.** The `tf.py_func()` operation has the following known limitations: @@ -290,10 +294,12 @@ def py_func(func, inp, Tout, stateful=True, name=None): server (e.g. using `with tf.device():`). Args: - func: A Python function, which accepts a list of NumPy `ndarray` objects - having element types that match the corresponding `tf.Tensor` objects - in `inp`, and returns a list of `ndarray` objects (or a single `ndarray`) - having element types that match the corresponding values in `Tout`. + func: A Python function, which accepts `ndarray` objects as arguments and + returns a list of `ndarray` objects (or a single `ndarray`). This function + must accept as many arguments as there are tensors in `inp`, and these + argument types will match the corresponding `tf.Tensor` objects + in `inp`. The returns `ndarray`s must match the number and types defined + `Tout`. Important Note: Input and output numpy `ndarray`s of `func` are not guaranteed to be copies. In some cases their underlying memory will be shared with the corresponding TensorFlow tensors. @@ -313,6 +319,12 @@ def py_func(func, inp, Tout, stateful=True, name=None): Returns: A list of `Tensor` or a single `Tensor` which `func` computes. """ + if context.in_eager_mode(): + result = func(*[x.numpy() for x in inp]) + result = nest.flatten(result) + + return [x if x is None else ops.convert_to_tensor(x) for x in result] + return _internal_py_func( func=func, inp=inp, Tout=Tout, stateful=stateful, eager=False, name=name) diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py index f6d9111009dc4f6a58ac81e7071ed7fe406600fa..b62e556967753dac4418add2864ce4e641dc6b58 100644 --- a/tensorflow/python/ops/standard_ops.py +++ b/tensorflow/python/ops/standard_ops.py @@ -60,6 +60,7 @@ from tensorflow.python.ops.io_ops import * from tensorflow.python.ops.linalg_ops import * from tensorflow.python.ops.logging_ops import Print from tensorflow.python.ops.logging_ops import get_summary_op +from tensorflow.python.ops.logging_ops import timestamp from tensorflow.python.ops.lookup_ops import initialize_all_tables from tensorflow.python.ops.lookup_ops import tables_initializer from tensorflow.python.ops.manip_ops import * @@ -232,7 +233,7 @@ _allowed_symbols_clip_ops = [ "global_norm", ] -_allowed_symbols_image_ops = [ +_allowed_symbols_logging_ops = [ # Documented in training.py. # We are not importing training.py to avoid complex dependencies. "audio_summary", @@ -262,8 +263,8 @@ _allowed_symbols = (_allowed_symbols_array_ops + _allowed_symbols_clip_ops + _allowed_symbols_control_flow_ops + _allowed_symbols_functional_ops + - _allowed_symbols_image_ops + _allowed_symbols_gradients + + _allowed_symbols_logging_ops + _allowed_symbols_math_ops + _allowed_symbols_variable_scope_ops + _allowed_symbols_misc + diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py index f00213eb88dce8e7bf73264a54780a704b4c4b18..6c0a090d16bb328de40f02edf9865a0e0a62d385 100644 --- a/tensorflow/python/ops/state_ops.py +++ b/tensorflow/python/ops/state_ops.py @@ -353,11 +353,9 @@ def scatter_update(ref, indices, updates, use_locking=True, name=None): if ref.dtype._is_ref_dtype: return gen_state_ops.scatter_update(ref, indices, updates, use_locking=use_locking, name=name) - with ops.control_dependencies( - [gen_resource_variable_ops.resource_scatter_update( - ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), - name=name)]): - return ref.read_value() + return ref._lazy_read(gen_resource_variable_ops.resource_scatter_update( # pylint: disable=protected-access + ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), + name=name)) @tf_export("scatter_nd_update") diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index 806fdd3da7aa6de01b7cd4d9d36dbf43f6139db6..424582b348d87d8a5b043ec9b771d8f2768a5994 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -557,6 +557,7 @@ class EagerTemplate(Template): # is created in __call__. variable_scope_name = None self._template_store = _EagerTemplateVariableStore(variable_scope_name) + self._variable_scope_context_manager = None def _call_func(self, args, kwargs): try: @@ -611,8 +612,12 @@ class EagerTemplate(Template): # the variable scope is opened in order to ensure that templates nested at # the same level correctly uniquify lower variable scope names. if self._variable_scope: - with variable_scope.variable_scope( - self._variable_scope, reuse=variable_scope.AUTO_REUSE): + # Create a cache for the variable scope context manager the first time + # around so that we don't have to keep recreating it. + if not self._variable_scope_context_manager: + self._variable_scope_context_manager = variable_scope.variable_scope( + self._variable_scope, reuse=variable_scope.AUTO_REUSE) + with self._variable_scope_context_manager: with self._template_store.as_default(): result = self._call_func(args, kwargs) return result diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 5cdf03509e3c427deec7e26345059211001e2131..3c08870146e447d84d4a5f620cbead633d94751f 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -653,7 +653,7 @@ class _EagerTensorArray(object): if len(tensors) > len(self._tensor_array) and not self._dynamic_size: raise ValueError( "Cannot unstack %d tensors into a TensorArray of static size %d" % - (len(tensors), len(self._tensors))) + (len(tensors), len(self._tensor_array))) ta = self._identity_without_array() ta._implementation._tensor_array = tensors # pylint: disable=protected-access return ta diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py index 19e3298e4019f94132db25ab0dae5ed458bfbeb3..d382683858be5d91755ef1a15ebbc6ae2287f8a7 100644 --- a/tensorflow/python/ops/variables.py +++ b/tensorflow/python/ops/variables.py @@ -29,6 +29,7 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpointable from tensorflow.python.util import compat from tensorflow.python.util import tf_should_use from tensorflow.python.util.deprecation import deprecated @@ -36,7 +37,7 @@ from tensorflow.python.util.tf_export import tf_export @tf_export("Variable") -class Variable(object): +class Variable(checkpointable.CheckpointableBase): """See the @{$variables$Variables How To} for a high level overview. A variable maintains state in the graph across calls to `run()`. You add a @@ -306,6 +307,11 @@ class Variable(object): if constraint is not None and not callable(constraint): raise ValueError("The `constraint` argument must be a callable.") + if isinstance(initial_value, checkpointable.CheckpointInitialValue): + self._maybe_initialize_checkpointable() + self._update_uid = initial_value.checkpoint_position.restore_uid + initial_value = initial_value.wrapped_value + if trainable and ops.GraphKeys.TRAINABLE_VARIABLES not in collections: collections = list(collections) + [ops.GraphKeys.TRAINABLE_VARIABLES] with ops.init_scope(): @@ -786,6 +792,10 @@ class Variable(object): setattr(Variable, operator, _run_op) + def _gather_saveables_for_checkpoint(self): + """For implementing `Checkpointable`. This object is saveable on its own.""" + return {checkpointable.VARIABLE_VALUE_KEY: self} + def _try_guard_against_uninitialized_dependencies(self, initial_value): """Attempt to guard against dependencies on uninitialized variables. diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i index 50f481d29e9d39bd12741b5f9e02b7201336134d..7ab0db526881109765adf83749bd01e4d543e5b2 100644 --- a/tensorflow/python/pywrap_tfe.i +++ b/tensorflow/python/pywrap_tfe.i @@ -29,9 +29,11 @@ limitations under the License. %rename("%s") TFE_OpNameGetAttrType; %rename("%s") TFE_Py_InitEagerTensor; %rename("%s") TFE_Py_RegisterExceptionClass; +%rename("%s") TFE_Py_RegisterBackwardFunctionGetter; %rename("%s") TFE_Py_RegisterFallbackExceptionClass; %rename("%s") TFE_Py_Execute; %rename("%s") TFE_Py_FastPathExecute; +%rename("%s") TFE_Py_RecordGradient; %rename("%s") TFE_Py_UID; %rename("%s") TFE_Py_TapeSetNew; %rename("%s") TFE_Py_TapeSetRemove; diff --git a/tensorflow/python/tools/freeze_graph.py b/tensorflow/python/tools/freeze_graph.py index 074b8e71326fa376ac10e89ee4b01d3ddc41adc6..a52f325ddbcd90ad011c1c056965912b96f27aaa 100644 --- a/tensorflow/python/tools/freeze_graph.py +++ b/tensorflow/python/tools/freeze_graph.py @@ -109,7 +109,7 @@ def freeze_graph_with_def_protos(input_graph_def, input_meta_graph_def, clear_devices=True) restorer.restore(sess, input_checkpoint) if initializer_nodes: - sess.run(initializer_nodes.split(",")) + sess.run(initializer_nodes.replace(" ", "").split(",")) elif input_saved_model_dir: if saved_model_tags is None: saved_model_tags = [] @@ -130,25 +130,27 @@ def freeze_graph_with_def_protos(input_graph_def, var_list=var_list, write_version=checkpoint_version) saver.restore(sess, input_checkpoint) if initializer_nodes: - sess.run(initializer_nodes.split(",")) + sess.run(initializer_nodes.replace(" ", "").split(",")) - variable_names_whitelist = (variable_names_whitelist.split(",") - if variable_names_whitelist else None) - variable_names_blacklist = (variable_names_blacklist.split(",") - if variable_names_blacklist else None) + variable_names_whitelist = ( + variable_names_whitelist.replace(" ", "").split(",") + if variable_names_whitelist else None) + variable_names_blacklist = ( + variable_names_blacklist.replace(" ", "").split(",") + if variable_names_blacklist else None) if input_meta_graph_def: output_graph_def = graph_util.convert_variables_to_constants( sess, input_meta_graph_def.graph_def, - output_node_names.split(","), + output_node_names.replace(" ", "").split(","), variable_names_whitelist=variable_names_whitelist, variable_names_blacklist=variable_names_blacklist) else: output_graph_def = graph_util.convert_variables_to_constants( sess, input_graph_def, - output_node_names.split(","), + output_node_names.replace(" ", "").split(","), variable_names_whitelist=variable_names_whitelist, variable_names_blacklist=variable_names_blacklist) @@ -250,7 +252,7 @@ def freeze_graph(input_graph, variable_names_blacklist, input_meta_graph_def, input_saved_model_dir, - saved_model_tags.split(","), + saved_model_tags.replace(" ", "").split(","), checkpoint_version=checkpoint_version) diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py index 33f6debbcbecb652774c776be54323bbaa824822..b0e9e3e5ed2117937bbd275784c44aebd2ea2515 100644 --- a/tensorflow/python/tools/saved_model_cli.py +++ b/tensorflow/python/tools/saved_model_cli.py @@ -115,7 +115,7 @@ def _get_outputs_tensor_info_from_meta_graph_def(meta_graph_def, signature_def_key).outputs -def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key): +def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, indent=0): """Prints input and output TensorInfos. Prints the details of input and output TensorInfos for the SignatureDef mapped @@ -126,6 +126,7 @@ def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key): tag_set: Group of tag(s) of the MetaGraphDef, in string format, separated by ','. For tag-set contains multiple tags, all tags must be passed in. signature_def_key: A SignatureDef key string. + indent: How far (in increments of 2 spaces) to indent each line of output. """ meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir, tag_set) @@ -134,29 +135,39 @@ def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key): outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def( meta_graph_def, signature_def_key) - print('The given SavedModel SignatureDef contains the following input(s):') + indent_str = " " * indent + def in_print(s): + print(indent_str + s) + + in_print('The given SavedModel SignatureDef contains the following input(s):') for input_key, input_tensor in sorted(inputs_tensor_info.items()): - print('inputs[\'%s\'] tensor_info:' % input_key) - _print_tensor_info(input_tensor) + in_print(' inputs[\'%s\'] tensor_info:' % input_key) + _print_tensor_info(input_tensor, indent+1) - print('The given SavedModel SignatureDef contains the following output(s):') + in_print('The given SavedModel SignatureDef contains the following ' + 'output(s):') for output_key, output_tensor in sorted(outputs_tensor_info.items()): - print('outputs[\'%s\'] tensor_info:' % output_key) - _print_tensor_info(output_tensor) + in_print(' outputs[\'%s\'] tensor_info:' % output_key) + _print_tensor_info(output_tensor, indent+1) - print('Method name is: %s' % - meta_graph_def.signature_def[signature_def_key].method_name) + in_print('Method name is: %s' % + meta_graph_def.signature_def[signature_def_key].method_name) -def _print_tensor_info(tensor_info): +def _print_tensor_info(tensor_info, indent=0): """Prints details of the given tensor_info. Args: tensor_info: TensorInfo object to be printed. + indent: How far (in increments of 2 spaces) to indent each line output """ - print(' dtype: ' + - {value: key - for (key, value) in types_pb2.DataType.items()}[tensor_info.dtype]) + indent_str = " " * indent + def in_print(s): + print(indent_str + s) + + in_print(' dtype: ' + + {value: key + for (key, value) in types_pb2.DataType.items()}[tensor_info.dtype]) # Display shape as tuple. if tensor_info.tensor_shape.unknown_rank: shape = 'unknown_rank' @@ -164,8 +175,8 @@ def _print_tensor_info(tensor_info): dims = [str(dim.size) for dim in tensor_info.tensor_shape.dim] shape = ', '.join(dims) shape = '(' + shape + ')' - print(' shape: ' + shape) - print(' name: ' + tensor_info.name) + in_print(' shape: ' + shape) + in_print(' name: ' + tensor_info.name) def _show_all(saved_model_dir): @@ -186,7 +197,8 @@ def _show_all(saved_model_dir): signature_def_map = get_signature_def_map(saved_model_dir, tag_set) for signature_def_key in sorted(signature_def_map.keys()): print('\nsignature_def[\'' + signature_def_key + '\']:') - _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key) + _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key, + indent=1) def get_meta_graph_def(saved_model_dir, tag_set): @@ -614,19 +626,19 @@ def create_parser(): show_msg = ( 'Usage examples:\n' 'To show all tag-sets in a SavedModel:\n' - '$saved_model_cli show --dir /tmp/saved_model\n' + '$saved_model_cli show --dir /tmp/saved_model\n\n' 'To show all available SignatureDef keys in a ' 'MetaGraphDef specified by its tag-set:\n' - '$saved_model_cli show --dir /tmp/saved_model --tag_set serve\n' + '$saved_model_cli show --dir /tmp/saved_model --tag_set serve\n\n' 'For a MetaGraphDef with multiple tags in the tag-set, all tags must be ' 'passed in, separated by \';\':\n' '$saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu\n\n' 'To show all inputs and outputs TensorInfo for a specific' ' SignatureDef specified by the SignatureDef key in a' ' MetaGraph.\n' - '$saved_model_cli show --dir /tmp/saved_model --tag_set serve ' - '--signature_def serving_default\n\n' - 'To show all available information in the SavedModel\n:' + '$saved_model_cli show --dir /tmp/saved_model --tag_set serve' + ' --signature_def serving_default\n\n' + 'To show all available information in the SavedModel:\n' '$saved_model_cli show --dir /tmp/saved_model --all') parser_show = subparsers.add_parser( 'show', @@ -658,12 +670,14 @@ def create_parser(): run_msg = ('Usage example:\n' 'To run input tensors from files through a MetaGraphDef and save' ' the output tensors to files:\n' - '$saved_model_cli show --dir /tmp/saved_model --tag_set serve ' - '--signature_def serving_default ' - '--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy ' - '--input_exprs \'input3_key=np.ones(2)\' --input_examples ' - '\'input4_key=[{"id":[26],"weights":[0.5, 0.5]}]\' ' - '--outdir=/out\n\n' + '$saved_model_cli show --dir /tmp/saved_model --tag_set serve \\\n' + ' --signature_def serving_default \\\n' + ' --inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy ' + '\\\n' + ' --input_exprs \'input3_key=np.ones(2)\' \\\n' + ' --input_examples ' + '\'input4_key=[{"id":[26],"weights":[0.5, 0.5]}]\' \\\n' + ' --outdir=/out\n\n' 'For more information about input file format, please see:\n' 'https://www.tensorflow.org/programmers_guide/saved_model_cli\n') parser_run = subparsers.add_parser( diff --git a/tensorflow/python/tools/saved_model_cli_test.py b/tensorflow/python/tools/saved_model_cli_test.py index d6cbc49ba1e08a6b808b228fb8d69fc14f36e3d2..f99c8448458078935fda477c6e4e15dde8d7d4ab 100644 --- a/tensorflow/python/tools/saved_model_cli_test.py +++ b/tensorflow/python/tools/saved_model_cli_test.py @@ -61,83 +61,84 @@ class SavedModelCLITestCase(test.TestCase): exp_out = """MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs: signature_def['classify_x2_to_y3']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x2:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['scores'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y3:0 -Method name is: tensorflow/serving/classify + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x2:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['scores'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y3:0 + Method name is: tensorflow/serving/classify signature_def['classify_x_to_y']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_STRING - shape: unknown_rank - name: tf_example:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['scores'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y:0 -Method name is: tensorflow/serving/classify + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_STRING + shape: unknown_rank + name: tf_example:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['scores'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 + Method name is: tensorflow/serving/classify signature_def['regress_x2_to_y3']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x2:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['outputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y3:0 -Method name is: tensorflow/serving/regress + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x2:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['outputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y3:0 + Method name is: tensorflow/serving/regress signature_def['regress_x_to_y']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_STRING - shape: unknown_rank - name: tf_example:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['outputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y:0 -Method name is: tensorflow/serving/regress + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_STRING + shape: unknown_rank + name: tf_example:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['outputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 + Method name is: tensorflow/serving/regress signature_def['regress_x_to_y2']: -The given SavedModel SignatureDef contains the following input(s): -inputs['inputs'] tensor_info: - dtype: DT_STRING - shape: unknown_rank - name: tf_example:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['outputs'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y2:0 -Method name is: tensorflow/serving/regress + The given SavedModel SignatureDef contains the following input(s): + inputs['inputs'] tensor_info: + dtype: DT_STRING + shape: unknown_rank + name: tf_example:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['outputs'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y2:0 + Method name is: tensorflow/serving/regress signature_def['serving_default']: -The given SavedModel SignatureDef contains the following input(s): -inputs['x'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: x:0 -The given SavedModel SignatureDef contains the following output(s): -outputs['y'] tensor_info: - dtype: DT_FLOAT - shape: (-1, 1) - name: y:0 -Method name is: tensorflow/serving/predict""" + The given SavedModel SignatureDef contains the following input(s): + inputs['x'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: x:0 + The given SavedModel SignatureDef contains the following output(s): + outputs['y'] tensor_info: + dtype: DT_FLOAT + shape: (-1, 1) + name: y:0 + Method name is: tensorflow/serving/predict""" # pylint: enable=line-too-long + self.maxDiff = None # Produce a useful error msg if the comparison fails self.assertMultiLineEqual(output, exp_out) self.assertEqual(err.getvalue().strip(), '') @@ -193,11 +194,11 @@ Method name is: tensorflow/serving/predict""" output = out.getvalue().strip() expected_output = ( 'The given SavedModel SignatureDef contains the following input(s):\n' - 'inputs[\'x\'] tensor_info:\n' - ' dtype: DT_FLOAT\n shape: (-1, 1)\n name: x:0\n' + ' inputs[\'x\'] tensor_info:\n' + ' dtype: DT_FLOAT\n shape: (-1, 1)\n name: x:0\n' 'The given SavedModel SignatureDef contains the following output(s):\n' - 'outputs[\'y\'] tensor_info:\n' - ' dtype: DT_FLOAT\n shape: (-1, 1)\n name: y:0\n' + ' outputs[\'y\'] tensor_info:\n' + ' dtype: DT_FLOAT\n shape: (-1, 1)\n name: y:0\n' 'Method name is: tensorflow/serving/predict') self.assertEqual(output, expected_output) self.assertEqual(err.getvalue().strip(), '') diff --git a/tensorflow/python/training/checkpoint_utils.py b/tensorflow/python/training/checkpoint_utils.py index fa3de6fad27b6cc773f9f2e86e9f95395eb7c285..8384d0ae943c7336ebe20093f7c0c0b89012c129 100644 --- a/tensorflow/python/training/checkpoint_utils.py +++ b/tensorflow/python/training/checkpoint_utils.py @@ -289,10 +289,16 @@ def _set_checkpoint_initializer(variable, name: Name of the operation. """ base_type = variable.dtype.base_dtype - with ops.colocate_with(variable): + # Do not colocate with variable since RestoreV2 op only runs on CPU and + # colocation will force variable (and other ops that colocate with variable) + # to be on CPU as well. It is okay to place the variable's initializer op on + # CPU since it will only be run once at the start. + with ops.device(variable.device), ops.device("/cpu:0"): restore_op = io_ops.restore_v2( ckpt_file, [tensor_name], [slice_spec], [base_type], name=name)[0] variable._initializer_op = state_ops.assign(variable, restore_op) # pylint:disable=protected-access + restore_op.set_shape(variable.shape) + variable._initial_value = restore_op # pylint:disable=protected-access def _set_variable_or_list_initializer(variable_or_list, ckpt_file, diff --git a/tensorflow/python/training/checkpoint_utils_test.py b/tensorflow/python/training/checkpoint_utils_test.py index cd17faa040d5b85263b54bc53100b18f736a12e0..f564871315f2981795accb438201b0131a49c1cb 100644 --- a/tensorflow/python/training/checkpoint_utils_test.py +++ b/tensorflow/python/training/checkpoint_utils_test.py @@ -145,6 +145,36 @@ class CheckpointsTest(test.TestCase): # Check that tensors are not explicitly in the graph. self.assertLess(len(str(session.graph.as_graph_def())), 29000) + def testInitialValueComesFromCheckpoint(self): + checkpoint_dir = self.get_temp_dir() + with self.test_session() as session: + v1, _, _, _ = _create_checkpoints(session, checkpoint_dir) + + # New graph and session. + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as session: + with variable_scope.variable_scope( + "some_scope", initializer=init_ops.zeros_initializer()): + my1 = variable_scope.get_variable("my1", [1, 10]) + + # At this point, my1.initialized_value() will add ops that reference + # the zeros initializer of my1. + before = variables.Variable(my1.initialized_value(), name="before") + + checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"var1": my1}) + + # At this point, my1.initialized_value() will add ops that reference + # the newly set initializer of my1. + after = variables.Variable(my1.initialized_value(), name="after") + + session.run(variables.global_variables_initializer()) + self.assertAllEqual(session.run(my1), v1) + self.assertAllEqual(session.run(my1.initialized_value()), v1) + self.assertAllClose(session.run(before), [[0.0] * 10]) + self.assertAllClose(session.run(after), v1) + with self.assertRaises(AssertionError): + self.assertAllClose(session.run(before), session.run(after)) + def testInitWithScopeDoesNotCaptureSuffixes(self): checkpoint_dir = self.get_temp_dir() with self.test_session() as session: @@ -176,7 +206,9 @@ class CheckpointsTest(test.TestCase): checkpoint_utils.init_from_checkpoint(checkpoint_dir, {"useful_scope/": "useful_scope/"}) - self.assertEqual(my4._initializer_op.op.inputs[1].device, "/job:ps") + # initializer runs on the same task but always on CPU. + self.assertEqual(my4._initializer_op.op.inputs[1].device, + "/job:ps/device:CPU:0") def testInitFromRootCheckpoint(self): checkpoint_dir = self.get_temp_dir() diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py new file mode 100644 index 0000000000000000000000000000000000000000..11caa761aec5d631d87a91ec876e0b5032ffdc5b --- /dev/null +++ b/tensorflow/python/training/checkpointable.py @@ -0,0 +1,588 @@ +"""An object-local variable management scheme.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_io_ops as io_ops +from tensorflow.python.util import nest + +# A key indicating a variable's value in an object's checkpointed Tensors +# (Checkpointable._gather_saveables_for_checkpoint). If this is the only key and +# the object has no dependencies, then its value may be restored on object +# creation (avoiding double assignment when executing eagerly). +VARIABLE_VALUE_KEY = "VARIABLE_VALUE" + +_CheckpointableReference = collections.namedtuple( + "_CheckpointableReference", + [ + # The local name for this dependency. + "name", + # The Checkpointable object being referenced. + "ref" + ]) + + +class CheckpointInitialValue(ops.Tensor): + """Tensor wrapper for managing update UIDs in `Variables`. + + When supplied as an initial value, objects of this type let a `Variable` + (`Variable`, `ResourceVariable`, etc.) know the UID of the restore the initial + value came from. This allows deferred restorations to be sequenced in the + order the user specified them, and lets us fall back on assignment if an + initial value is not set (e.g. due to a custom getter interfering). + + See comments in _add_variable_with_custom_getter for more information about + how `CheckpointInitialValue` is used. + """ + + def __init__(self, checkpoint_position, shape=None): + self.wrapped_value = checkpoint_position.value_tensors()[ + VARIABLE_VALUE_KEY] + if shape: + # We need to set the static shape information on the initializer if + # possible so we don't get a variable with an unknown shape. + self.wrapped_value.set_shape(shape) + self._checkpoint_position = checkpoint_position + + @property + def __class__(self): + return (self.wrapped_value.__class__, CheckpointInitialValue) + + def __getattr__(self, attr): + try: + return getattr(self.wrapped_value, attr) + except AttributeError: + return self.__getattribute__(attr) + + @property + def checkpoint_position(self): + return self._checkpoint_position + + +class _CheckpointPosition(object): + """Indicates a position within a `_Checkpoint`.""" + + def __init__(self, checkpoint, proto_id): + """Specify an object within a checkpoint. + + Args: + checkpoint: A _Checkpoint object. + proto_id: The index of this object in CheckpointableObjectGraph.nodes. + """ + self._checkpoint = checkpoint + self._proto_id = proto_id + + def restore(self, checkpointable): + """Restore this value into `checkpointable`.""" + if self.bind_object(checkpointable): + # This object's correspondence with a checkpointed object is new, so + # process deferred restorations for it and its dependencies. + restore_ops = checkpointable._restore_from_checkpoint_position(self) # pylint: disable=protected-access + if restore_ops: + self._checkpoint.restore_ops.extend(restore_ops) + + def bind_object(self, checkpointable): + """Set a checkpoint<->object correspondence and process slot variables. + + Args: + checkpointable: The object to record a correspondence for. + Returns: + True if this is a new assignment, False if this object has already been + mapped to a checkpointed `Object` proto. + Raises: + AssertionError: If another object is already bound to the `Object` proto. + """ + checkpoint = self.checkpoint + current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None) + if current_assignment is None: + checkpoint.object_by_proto_id[self._proto_id] = checkpointable + for deferred_slot_restoration in ( + checkpoint.deferred_slot_restorations.pop(self._proto_id, ())): + checkpointable._create_or_restore_slot_variable( # pylint: disable=protected-access + slot_variable_position=_CheckpointPosition( + checkpoint=checkpoint, + proto_id=deferred_slot_restoration.slot_variable_id), + variable=deferred_slot_restoration.original_variable, + slot_name=deferred_slot_restoration.slot_name) + for slot_restoration in checkpoint.slot_restorations.pop( + self._proto_id, ()): + optimizer_object = checkpoint.object_by_proto_id.get( + slot_restoration.optimizer_id, None) + if optimizer_object is None: + # The optimizer has not yet been created or tracked. Record in the + # checkpoint that the slot variables need to be restored when it is. + checkpoint.deferred_slot_restorations.setdefault( + slot_restoration.optimizer_id, []).append( + _DeferredSlotVariableRestoration( + original_variable=checkpointable, + slot_variable_id=slot_restoration.slot_variable_id, + slot_name=slot_restoration.slot_name)) + else: + optimizer_object._create_or_restore_slot_variable( # pylint: disable=protected-access + slot_variable_position=_CheckpointPosition( + checkpoint=checkpoint, + proto_id=slot_restoration.slot_variable_id), + variable=checkpointable, + slot_name=slot_restoration.slot_name) + return True # New assignment + else: + # The object was already mapped for this checkpoint load, which means + # we don't need to do anything besides check that the mapping is + # consistent (if the dependency DAG is not a tree then there are + # multiple paths to the same object). + if current_assignment is not checkpointable: + raise AssertionError( + ("Unable to load the checkpoint into this object graph. Either " + "the Checkpointable object references in the Python program " + "have changed in an incompatible way, or the checkpoint was " + "generated in an incompatible program.\n\nTwo checkpoint " + "references resolved to different objects (%s and %s).") + % (current_assignment, checkpointable)) + return False # Not a new assignment + + def is_simple_variable(self): + """Determine whether this value is restorable with a Tensor initializer.""" + attributes = self.object_proto.attributes + return (len(attributes) == 1 + and attributes[0].name == VARIABLE_VALUE_KEY + and not self.object_proto.children) + + def value_tensors(self): + """Create value `Tensor`s for this object's attributes. + + Does not require that the Python object has been created. Used for + restore-on-create when executing eagerly. + + Returns: + A dictionary mapping from object attribute names to `Tensor`s. + """ + value_tensors = {} + for serialized_tensor in self.object_proto.attributes: + checkpoint_key = serialized_tensor.checkpoint_key + dtype = self._checkpoint.dtype_map[checkpoint_key] + base_type = dtype.base_dtype + with ops.init_scope(): + value, = io_ops.restore_v2( + prefix=self._checkpoint.save_path, + tensor_names=[checkpoint_key], + shape_and_slices=[""], + dtypes=[base_type], + name="%s_checkpoint_read" % (serialized_tensor.name,)) + value_tensors[serialized_tensor.name] = value + return value_tensors + + def restore_ops(self): + """Create or fetch restore ops for this object's attributes. + + Requires that the `Checkpointable` Python object has been bound to an object + ID in the checkpoint. + + Returns: + A list of operations when graph building, or an empty list when executing + eagerly. + """ + saveables = self.checkpointable._gather_saveables_for_checkpoint() # pylint: disable=protected-access + # Name saveables based on the name this object had when it was checkpointed. + named_saveables = {} + restore_ops = [] + in_graph_mode = context.in_graph_mode() + for serialized_tensor in self.object_proto.attributes: + saveable_object = saveables.get(serialized_tensor.name, None) + if saveable_object is None: + # Purposefully does not throw an exception if attributes have been added + # or deleted. Stores unused attributes so an exception can be raised if + # the user decides to check that everything in the checkpoint was + # loaded. + self._checkpoint.unused_attributes.setdefault( + self.checkpointable, []).append(serialized_tensor.name) + continue + if in_graph_mode: + existing_ops = self._checkpoint.restore_ops_by_name.get( + serialized_tensor.name, None) + else: + existing_ops = None + if existing_ops is None: + named_saveables[serialized_tensor.checkpoint_key] = saveable_object + if named_saveables: + validated_saveables = ( + self._checkpoint.builder._ValidateAndSliceInputs(named_saveables)) # pylint: disable=protected-access + validated_names = set(saveable.name for saveable in validated_saveables) + if set(named_saveables.keys()) != validated_names: + raise AssertionError( + ("Saveable keys changed when validating. Got back %s, was " + "expecting %s") % (named_saveables.keys(), validated_names)) + all_tensors = self._checkpoint.builder.bulk_restore( + filename_tensor=self._checkpoint.save_path, + saveables=validated_saveables, preferred_shard=-1, + restore_sequentially=False) + saveable_index = 0 + for saveable in validated_saveables: + num_specs = len(saveable.specs) + saveable_tensors = all_tensors[ + saveable_index:saveable_index + num_specs] + saveable_index += num_specs + restore_op = saveable.restore(saveable_tensors, restored_shapes=None) + if in_graph_mode: + assert saveable.name not in self._checkpoint.restore_ops_by_name + self._checkpoint.restore_ops_by_name[saveable.name] = restore_op + restore_ops.append(restore_op) + return restore_ops + + @property + def checkpoint(self): + return self._checkpoint + + @property + def checkpointable(self): + return self._checkpoint.object_by_proto_id[self._proto_id] + + @property + def object_proto(self): + return self._checkpoint.object_graph_proto.nodes[self._proto_id] + + @property + def restore_uid(self): + return self._checkpoint.restore_uid + + def __repr__(self): + return repr(self.object_proto) + + +_DeferredSlotVariableRestoration = collections.namedtuple( + "_DeferredSlotVariableRestoration", + [ + "original_variable", + "slot_variable_id", + "slot_name", + ] +) + +_SlotVariableRestoration = collections.namedtuple( + "_SlotVariableRestoration", + [ + # The checkpoint proto id of the optimizer object. + "optimizer_id", + # The checkpoint proto id of the slot variable. + "slot_variable_id", + "slot_name", + ]) + + +class CheckpointableBase(object): + """Base class for `Checkpointable` objects without automatic dependencies. + + This class has no __setattr__ override for performance reasons. Dependencies + must be added explicitly. Unless attribute assignment is performance-critical, + use `Checkpointable` instead. Use `CheckpointableBase` for `isinstance` + checks. + """ + + def _maybe_initialize_checkpointable(self): + """Initialize dependency management. + + Not __init__, since most objects will forget to call it. + """ + if hasattr(self, "_checkpoint_dependencies"): + # __init__ already called. This check means that we don't need + # Checkpointable.__init__() in the constructor of every TensorFlow object. + return + # A list of _CheckpointableReference objects. + self._checkpoint_dependencies = [] + # Maps names -> Checkpointable objects + self._dependency_names = {} + # Restorations for other Checkpointable objects on which this object may + # eventually depend. + self._deferred_dependencies = {} # local name -> _CheckpointPosition list + # The UID of the highest assignment to this object. Used to ensure that the + # last requested assignment determines the final value of an object. + if hasattr(self, "_update_uid"): + raise AssertionError( + "Internal error: the object had an update UID set before its " + "initialization code was run.") + self._update_uid = -1 + + def _add_variable_with_custom_getter( + self, name, shape=None, dtype=dtypes.float32, + initializer=None, getter=None, **kwargs_for_getter): + """Restore-on-create for a variable be saved with this `Checkpointable`. + + If the user has requested that this object or another `Checkpointable` which + depends on this object be restored from a checkpoint (deferred loading + before variable object creation), `initializer` may be ignored and the value + from the checkpoint used instead. + + Args: + name: A name for the variable. Must be unique within this object. + shape: The shape of the variable. + dtype: The data type of the variable. + + initializer: The initializer to use. Ignored if there is a deferred + restoration left over from a call to + `_restore_from_checkpoint_position`. + + getter: The getter to wrap which actually fetches the variable. + **kwargs_for_getter: Passed to the getter. + + Returns: + The new variable object. + + Raises: + ValueError: If the variable name is not unique. + """ + self._maybe_initialize_checkpointable() + if name in self._dependency_names: + raise ValueError( + ("A variable named '%s' already exists in this Checkpointable, but " + "Checkpointable._add_variable called to create another with " + "that name. Variable names must be unique within a Checkpointable " + "object.") % (name,)) + if context.in_eager_mode(): + # If this is a variable with a single Tensor stored in the checkpoint, we + # can set that value as an initializer rather than initializing and then + # assigning (when executing eagerly). This call returns None if there is + # nothing to restore. + checkpoint_initializer = self._preload_simple_restoration( + name=name, shape=shape) + else: + checkpoint_initializer = None + if (checkpoint_initializer is not None + and not ( + isinstance(initializer, CheckpointInitialValue) + and initializer.restore_uid > checkpoint_initializer.restore_uid)): + # If multiple Checkpointable objects are "creating" the same variable via + # the magic of custom getters, the one with the highest restore UID (the + # one called last) has to make the final initializer. If another custom + # getter interrupts this process by overwriting the initializer, then + # we'll catch that when we call _track_checkpointable. So this is "best + # effort" to set the initializer with the highest restore UID. + initializer = checkpoint_initializer + shape = None + + new_variable = getter( + name=name, shape=shape, dtype=dtype, initializer=initializer, + **kwargs_for_getter) + + # If we set an initializer and the variable processed it, tracking will not + # assign again. It will add this variable to our dependencies, and if there + # is a non-trivial restoration queued, it will handle that. This also + # handles slot variables. + return self._track_checkpointable(new_variable, name=name) + + def _preload_simple_restoration(self, name, shape): + """Return a dependency's value for restore-on-create. + + Note the restoration is not deleted; if for some reason preload is called + and then not assigned to the variable (for example because a custom getter + overrides the initializer), the assignment will still happen once the + variable is tracked (determined based on checkpoint.restore_uid). + + Args: + name: The object-local name of the dependency holding the variable's + value. + shape: The shape of the variable being loaded into. + Returns: + An callable for use as a variable's initializer/initial_value, or None if + one should not be set (either because there was no variable with this name + in the checkpoint or because it needs more complex deserialization). Any + non-trivial deserialization will happen when the variable object is + tracked. + """ + deferred_dependencies_list = self._deferred_dependencies.get(name, ()) + if not deferred_dependencies_list: + # Nothing to do; we don't have a restore for this dependency queued up. + return + for checkpoint_position in deferred_dependencies_list: + if not checkpoint_position.is_simple_variable(): + # If _any_ pending restoration is too complicated to fit in an + # initializer (because it has dependencies, or because there are + # multiple Tensors to restore), bail and let the general tracking code + # handle it. + return None + checkpoint_position = max( + deferred_dependencies_list, + key=lambda restore: restore.checkpoint.restore_uid) + return CheckpointInitialValue( + checkpoint_position=checkpoint_position, shape=shape) + + def _track_checkpointable(self, checkpointable, name, overwrite=False): + """Declare a dependency on another `Checkpointable` object. + + Indicates that checkpoints for this object should include variables from + `checkpointable`. + + Variables in a checkpoint are mapped to `Checkpointable`s based on the names + provided when the checkpoint was written. To avoid breaking existing + checkpoints when modifying a class, neither variable names nor dependency + names (the names passed to `_track_checkpointable`) may change. + + Args: + checkpointable: A `Checkpointable` which this object depends on. + name: A local name for `checkpointable`, used for loading checkpoints into + the correct objects. + overwrite: Boolean, whether silently replacing dependencies is OK. Used + for __setattr__, where throwing an error on attribute reassignment would + be inappropriate. + + Returns: + `checkpointable`, for convenience when declaring a dependency and + assigning to a member variable in one statement. + + Raises: + TypeError: If `checkpointable` does not inherit from `Checkpointable`. + ValueError: If another object is already tracked by this name. + """ + self._maybe_initialize_checkpointable() + if not isinstance(checkpointable, CheckpointableBase): + raise TypeError( + ("Checkpointable._track_checkpointable() passed type %s, not a " + "Checkpointable.") % (type(checkpointable),)) + new_reference = _CheckpointableReference(name=name, ref=checkpointable) + if (name in self._dependency_names + and self._dependency_names[name] is not checkpointable): + if not overwrite: + raise ValueError( + ("Called Checkpointable._track_checkpointable() with name='%s', " + "but a Checkpointable with this name is already declared as a " + "dependency. Names must be unique (or overwrite=True).") % (name,)) + # This is a weird thing to do, but we're not going to stop people from + # using __setattr__. + for index, (old_name, _) in enumerate(self._checkpoint_dependencies): + if name == old_name: + self._checkpoint_dependencies[index] = new_reference + else: + self._checkpoint_dependencies.append(new_reference) + + self._dependency_names[name] = checkpointable + deferred_dependency_list = self._deferred_dependencies.pop(name, None) + if deferred_dependency_list is not None: + for checkpoint_position in deferred_dependency_list: + checkpoint_position.restore(checkpointable=checkpointable) + return checkpointable + + def _restore_from_checkpoint_position(self, checkpoint_position): + """Restore this object and its dependencies (may be deferred).""" + # Attempt a breadth-first traversal, since presumably the user has more + # control over shorter paths. If we don't have all of the dependencies at + # this point, the end result is not breadth-first (since other deferred + # traversals will happen later). + visit_queue = collections.deque([checkpoint_position]) + restore_ops = [] + while visit_queue: + current_position = visit_queue.popleft() + restore_ops.extend(nest.flatten( + current_position.checkpointable # pylint: disable=protected-access + ._single_restoration_from_checkpoint_position( + checkpoint_position=current_position, + visit_queue=visit_queue))) + return restore_ops + + def _single_restoration_from_checkpoint_position( + self, checkpoint_position, visit_queue): + """Restore this object, and either queue its dependencies or defer them.""" + self._maybe_initialize_checkpointable() + checkpoint = checkpoint_position.checkpoint + # If the UID of this restore is lower than our current update UID, we don't + # need to actually restore the object. However, we should pass the + # restoration on to our dependencies. + if checkpoint.restore_uid > self._update_uid: + restore_ops = checkpoint_position.restore_ops() + # TODO(allenl): Get a list of feeds for saving Python state + self._update_uid = checkpoint.restore_uid + else: + restore_ops = () + for child in checkpoint_position.object_proto.children: + child_position = _CheckpointPosition( + checkpoint=checkpoint, + proto_id=child.node_id) + local_object = self._dependency_names.get(child.local_name, None) + if local_object is None: + # We don't yet have a dependency registered with this name. Save it + # in case we do. + self._deferred_dependencies.setdefault(child.local_name, []).append( + child_position) + else: + if child_position.bind_object(checkpointable=local_object): + # This object's correspondence is new, so dependencies need to be + # visited. Delay doing it so that we get a breadth-first dependency + # resolution order (shallowest paths first). The caller is responsible + # for emptying visit_queue. + visit_queue.append(child_position) + return restore_ops + + def _gather_saveables_for_checkpoint(self): + """Returns a dictionary of values to checkpoint with this object. + + Keys in the returned dictionary are local to this object and in a separate + namespace from dependencies. Values may either be `SaveableObject`s or + variables easily converted to `SaveableObject`s (as in `tf.train.Saver`'s + `var_list` constructor argument). + + Returned values must be saved only by this object; if any value may be + shared, it should instead be a dependency. For example, variable objects + save their own values with the key `VARIABLE_VALUE_KEY`, but objects which + reference variables simply add a dependency. + """ + return {} + + +class Checkpointable(CheckpointableBase): + """Manages dependencies on other objects. + + `Checkpointable` objects may have dependencies: other `Checkpointable` objects + which should be saved if the object declaring the dependency is saved. A + correctly saveable program has a dependency graph such that if changing a + global variable affects an object (e.g. changes the behavior of any of its + methods) then there is a chain of dependencies from the influenced object to + the variable. + + Dependency edges have names, and are created implicitly when a + `Checkpointable` object is assigned to an attribute of another + `Checkpointable` object. For example: + + ``` + obj = Checkpointable() + obj.v = ResourceVariable(0.) + ``` + + The `Checkpointable` object `obj` now has a dependency named "v" on a + variable. + + `Checkpointable` objects may specify `Tensor`s to be saved and restored + directly (e.g. a `Variable` indicating how to save itself) rather than through + dependencies on other objects. See + `Checkpointable._gather_saveables_for_checkpoint` for details. + """ + + def __setattr__(self, name, value): + """Support self.foo = checkpointable syntax.""" + # Perform the attribute assignment, and potentially call other __setattr__ + # overrides such as that for tf.keras.Model. + super(Checkpointable, self).__setattr__(name, value) + if isinstance(value, CheckpointableBase): + self._track_checkpointable( + value, name=name, + # Allow the user to switch the Checkpointable which is tracked by this + # name, since assigning a new variable to an attribute has + # historically been fine (e.g. Adam did this). + # TODO(allenl): Should this be a warning once Checkpointable save/load + # is usable? + overwrite=True) diff --git a/tensorflow/python/training/checkpointable_test.py b/tensorflow/python/training/checkpointable_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e79acb49758b6a7d69dd084692d434bea808db64 --- /dev/null +++ b/tensorflow/python/training/checkpointable_test.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.platform import test +from tensorflow.python.training import checkpointable + + +class InterfaceTests(test.TestCase): + + def testMultipleAssignment(self): + root = checkpointable.Checkpointable() + root.leaf = checkpointable.Checkpointable() + root.leaf = root.leaf + duplicate_name_dep = checkpointable.Checkpointable() + with self.assertRaises(ValueError): + root._track_checkpointable(duplicate_name_dep, name="leaf") + # No error; we're overriding __setattr__, so we can't really stop people + # from doing this while maintaining backward compatibility. + root.leaf = duplicate_name_dep + root._track_checkpointable(duplicate_name_dep, name="leaf", overwrite=True) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32123f87ef2d12497077ab0e2f7d4d4cad1ec5dd --- /dev/null +++ b/tensorflow/python/training/checkpointable_utils.py @@ -0,0 +1,78 @@ +"""Utilities for saving/loading Checkpointable objects.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import weakref + +from tensorflow.python.framework import ops +from tensorflow.python.training import checkpointable +from tensorflow.python.training import saver as saver_lib + + +class _Checkpoint(object): + """Holds the status of an object-based checkpoint load.""" + + def __init__(self, object_graph_proto, save_path, dtype_map=None): + """Specify the checkpoint being loaded. + + Args: + object_graph_proto: The CheckpointableObjectGraph protocol buffer + associated with this checkpoint. + save_path: A string `Tensor`. The path to the checkpoint, as returned by + `tf.train.latest_checkpoint`. + dtype_map: When executing eagerly, specifies dtypes for creating slot + variables. None when graph building. + """ + self.builder = saver_lib.BulkSaverBuilder() + self.object_graph_proto = object_graph_proto + self.restore_uid = ops.uid() + # Maps from objects to lists of attributes which were in the checkpoint but + # not loaded into any object, for error checking. + self.unused_attributes = weakref.WeakKeyDictionary() + # Dictionary mapping from an id in the protocol buffer flat array to + # Checkpointable Python objects. This mapping may be deferred if a + # checkpoint is restored before all dependencies have been tracked. Uses + # weak references so that partial restorations don't create reference cycles + # (as objects with deferred dependencies will generally have references to + # this object). + self.object_by_proto_id = weakref.WeakValueDictionary() + self.save_path = save_path + self.dtype_map = dtype_map + # When graph building, contains a list of ops to run to restore objects from + # this checkpoint. + self.restore_ops = [] + self.restore_ops_by_name = {} + # A mapping from optimizer proto ids to lists of slot variables to be + # restored when the optimizer is tracked. Only includes slot variables whose + # regular variables have already been created, and only for optimizer + # objects which have not yet been created/tracked. + self.deferred_slot_restorations = {} + # A mapping from variable proto ids to lists of slot variables to be + # restored when the variable is created/tracked. These get shifted over to + # deferred_slot_restorations if the optimizer hasn't been created when that + # happens. + self.slot_restorations = {} + for node_index, node in enumerate(self.object_graph_proto.nodes): + for slot_reference in node.slot_variables: + # `node` refers to an `Optimizer`, since only these have slot variables. + self.slot_restorations.setdefault( + slot_reference.original_variable_node_id, []).append( + checkpointable._SlotVariableRestoration( # pylint: disable=protected-access + optimizer_id=node_index, + slot_variable_id=slot_reference.slot_variable_node_id, + slot_name=slot_reference.slot_name)) diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py index f05c40b32dcbedddb350ee8a61ad4616c666b86a..454cc3add5c8a5b39385a4a2b48ebe3c5ef2336f 100644 --- a/tensorflow/python/training/optimizer.py +++ b/tensorflow/python/training/optimizer.py @@ -34,6 +34,7 @@ from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables +from tensorflow.python.training import checkpointable from tensorflow.python.training import slot_creator from tensorflow.python.util import nest from tensorflow.python.util.tf_export import tf_export @@ -97,6 +98,9 @@ class _RefVariableProcessor(_OptimizableVariable): def __init__(self, v): self._v = v + def __str__(self): + return "<_RefVariableProcessor(%s)>" % self._v + def target(self): return self._v._ref() # pylint: disable=protected-access @@ -212,7 +216,7 @@ def _get_processor(v): @tf_export("train.Optimizer") -class Optimizer(object): +class Optimizer(checkpointable.Checkpointable): """Base class for optimizers. This class defines the API to add Ops to train a model. You never use this @@ -323,9 +327,18 @@ class Optimizer(object): self._use_locking = use_locking self._name = name # Dictionary of slots. - # {slot_name : { variable_to_train: slot_for_the_variable, ...}, ... } + # {slot_name : + # {_var_key(variable_to_train): slot_for_the_variable, ... }, + # ... } self._slots = {} self._non_slot_dict = {} + # For implementing Checkpointable. Stores information about how to restore + # slot variables which have not yet been created + # (checkpointable._CheckpointPosition objects). + # {slot_name : + # {_var_key(variable_to_train): [checkpoint_position, ... ], ... }, + # ... } + self._deferred_slot_restorations = {} def get_name(self): return self._name @@ -883,7 +896,11 @@ class Optimizer(object): """ named_slots = self._slot_dict(slot_name) if _var_key(var) not in named_slots: - named_slots[_var_key(var)] = slot_creator.create_slot(var, val, op_name) + new_slot_variable = slot_creator.create_slot(var, val, op_name) + self._restore_slot_variable( + slot_name=slot_name, variable=var, + slot_variable=new_slot_variable) + named_slots[_var_key(var)] = new_slot_variable return named_slots[_var_key(var)] def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype, @@ -904,8 +921,12 @@ class Optimizer(object): """ named_slots = self._slot_dict(slot_name) if _var_key(var) not in named_slots: - named_slots[_var_key(var)] = slot_creator.create_slot_with_initializer( + new_slot_variable = slot_creator.create_slot_with_initializer( var, initializer, shape, dtype, op_name) + self._restore_slot_variable( + slot_name=slot_name, variable=var, + slot_variable=new_slot_variable) + named_slots[_var_key(var)] = new_slot_variable return named_slots[_var_key(var)] def _zeros_slot(self, var, slot_name, op_name): @@ -922,5 +943,79 @@ class Optimizer(object): """ named_slots = self._slot_dict(slot_name) if _var_key(var) not in named_slots: - named_slots[_var_key(var)] = slot_creator.create_zeros_slot(var, op_name) + new_slot_variable = slot_creator.create_zeros_slot(var, op_name) + self._restore_slot_variable( + slot_name=slot_name, variable=var, + slot_variable=new_slot_variable) + named_slots[_var_key(var)] = new_slot_variable return named_slots[_var_key(var)] + + # -------------- + # For implementing the Checkpointable interface. + # -------------- + + def _restore_slot_variable(self, slot_name, variable, slot_variable): + """Restore a newly created slot variable's value.""" + variable_key = _var_key(variable) + deferred_restorations = self._deferred_slot_restorations.get( + slot_name, {}).pop(variable_key, []) + # Iterate over restores, highest restore UID first to minimize the number + # of assignments. + deferred_restorations.sort(key=lambda position: position.restore_uid, + reverse=True) + for checkpoint_position in deferred_restorations: + checkpoint_position.restore(slot_variable) + + def _create_or_restore_slot_variable( + self, slot_variable_position, slot_name, variable): + """Restore a slot variable's value, possibly creating it. + + Called when a variable which has an associated slot variable is created or + restored. When executing eagerly, we create the slot variable with a + restoring initializer. + + No new variables are created when graph building. Instead, + _restore_slot_variable catches these after normal creation and adds restore + ops to the graph. This method is nonetheless important when graph building + for the case when a slot variable has already been created but `variable` + has just been added to a dependency graph (causing us to realize that the + slot variable needs to be restored). + + Args: + slot_variable_position: A `checkpointable._CheckpointPosition` object + indicating the slot variable `Checkpointable` object to be restored. + slot_name: The name of this `Optimizer`'s slot to restore into. + variable: The variable object this slot is being created for. + """ + named_slots = self._slot_dict(slot_name) + variable_key = _var_key(variable) + slot_variable = named_slots.get(variable_key, None) + if (slot_variable is None + and context.in_eager_mode() + and slot_variable_position.is_simple_variable()): + initializer = checkpointable.CheckpointInitialValue( + checkpoint_position=slot_variable_position) + slot_variable = self._get_or_make_slot( + var=variable, + val=initializer, + slot_name=slot_name, + op_name=self._name) + # Slot variables are not owned by any one object (because we don't want to + # save the slot variable if the optimizer is saved without the non-slot + # variable, or if the non-slot variable is saved without the optimizer; + # it's a dependency hypergraph with edges of the form (optimizer, non-slot + # variable, variable)). So we don't _track_ slot variables anywhere, and + # instead special-case this dependency and otherwise pretend it's a normal + # graph. + if slot_variable is not None: + # If we've either made this slot variable, or if we've pulled out an + # existing slot variable, we should restore it. + slot_variable_position.restore(slot_variable) + else: + # We didn't make the slot variable. Defer restoring until it gets created + # normally. We keep a list rather than the one with the highest restore + # UID in case slot variables have their own dependencies, in which case + # those could differ between restores. + self._deferred_slot_restorations.setdefault( + slot_name, {}).setdefault(variable_key, []).append( + slot_variable_position) diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py index 0c1c8e664b682f78c69a5244db0773df80b35be7..9afd1e6643f7443bc9bdc5dc2b77ef4402772c38 100644 --- a/tensorflow/python/training/saver.py +++ b/tensorflow/python/training/saver.py @@ -50,6 +50,7 @@ from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import checkpointable from tensorflow.python.training import training_util from tensorflow.python.training.checkpoint_state_pb2 import CheckpointState from tensorflow.python.util import compat @@ -196,8 +197,8 @@ class BaseSaverBuilder(object): # Copy the restored tensor to the variable's device. with ops.device(self._var_device): restored_tensor = array_ops.identity(restored_tensor) - return resource_variable_ops.shape_safe_assign_variable_handle( - self.handle_op, self._var_shape, restored_tensor) + return resource_variable_ops.shape_safe_assign_variable_handle( + self.handle_op, self._var_shape, restored_tensor) def __init__(self, write_version=saver_pb2.SaverDef.V2): self._write_version = write_version @@ -577,6 +578,11 @@ class BaseSaverBuilder(object): names_to_saveables[name].append(var) else: names_to_saveables[name] = [var] + elif (isinstance(var, checkpointable.CheckpointableBase) + and not isinstance(var, variables.Variable)): + names_to_saveables.update( + BaseSaverBuilder.OpListToDict( + list(var._gather_saveables_for_checkpoint().values()))) else: if context.in_graph_mode(): if convert_variable_to_tensor: @@ -1597,9 +1603,9 @@ class Saver(object): [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: - A string: path prefix used for the checkpoint files. If checkpoint - format is V1 and the saver is sharded, this string ends with: - '-?????-of-nnnnn' where 'nnnnn' is the number of shards created. + A string: path prefix used for the checkpoint files. If the saver is + sharded, this string ends with: '-?????-of-nnnnn' where 'nnnnn' + is the number of shards created. If the saver is empty, returns None. Raises: @@ -1749,12 +1755,6 @@ class Saver(object): return if save_path is None: raise ValueError("Can't load save_path when it is None.") - if (os.path.isfile(save_path) and - self._write_version not in ( - saver_pb2.SaverDef.V1, saver_pb2.SaverDef.LEGACY)): - raise ValueError("The specified path: %s is a file." - " Please specify only the path prefix" - " to the checkpoint files." % save_path) logging.info("Restoring parameters from %s", save_path) if context.in_graph_mode(): sess.run(self.saver_def.restore_op_name, diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py index c5a6f49df599434ab3bc1a9fe3d85db6f824071e..b758ceaab02d3ad7b79adb64b7a724327f2d6623 100644 --- a/tensorflow/python/training/saver_test.py +++ b/tensorflow/python/training/saver_test.py @@ -53,6 +53,7 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables @@ -66,6 +67,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.summary import summary from tensorflow.python.training import adam +from tensorflow.python.training import checkpointable from tensorflow.python.training import gradient_descent from tensorflow.python.training import queue_runner_impl from tensorflow.python.training import saver as saver_module @@ -2039,6 +2041,80 @@ class MetaGraphTest(test.TestCase): self._testGraphExtensionRestore(test_dir) self._testRestoreFromTrainGraphWithControlContext(test_dir) + def _testWhileLoopAndGradientSerDes(self, outer_body_fn): + # Build a while loop with `outer_body_fn`, export it, and verify that it can + # be imported and the gradient can be built and run correctly. + + test_dir = self._get_test_dir("nested_control_flow") + filename = os.path.join(test_dir, "metafile") + saver_ckpt = os.path.join(test_dir, "saver.ckpt") + + # Create while loop using `outer_body_fn`. + with ops_lib.Graph().as_default(): + var = variables.Variable(0) + var_name = var.name + _, output = control_flow_ops.while_loop(lambda i, x: i < 5, outer_body_fn, + [0, var]) + output_name = output.name + init_op = variables.global_variables_initializer() + + # Generate a MetaGraphDef containing the while loop. + with session.Session() as sess: + sess.run(init_op) + sess.run(output) + saver = saver_module.Saver() + saver.save(sess, saver_ckpt) + saver.export_meta_graph(filename) + + # Build and run the gradients of the while loop. We use this below to + # verify that the gradients are correct with an imported MetaGraphDef. + grad = gradients_impl.gradients([output], [var]) + with session.Session() as sess: + sess.run(init_op) + expected_grad_value = sess.run(grad) + + # Restore the MetaGraphDef into a new Graph. + with ops_lib.Graph().as_default(): + with session.Session() as sess: + saver = saver_module.import_meta_graph(filename) + saver.restore(sess, saver_ckpt) + + # Make sure we can still build gradients and get the same result. + var = ops_lib.get_default_graph().get_tensor_by_name(var_name) + output = ops_lib.get_default_graph().get_tensor_by_name(output_name) + grad = gradients_impl.gradients([output], [var]) + + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + actual_grad_value = sess.run(grad) + self.assertEqual(expected_grad_value, actual_grad_value) + + def testNestedWhileLoopsSerDes(self): + # Test two simple nested while loops. + def body(i, x): + _, r = control_flow_ops.while_loop(lambda j, y: j < 3, + lambda j, y: (j + 1, y + x), + [0, 0]) + return i + 1, x + r + self._testWhileLoopAndGradientSerDes(body) + + def testNestedControlFlowSerDes(self): + # Test while loop in a cond in a while loop. + # pylint: disable=g-long-lambda + def body(i, x): + cond_result = control_flow_ops.cond( + i > 0, + lambda: control_flow_ops.while_loop( + lambda j, y: j < 3, + lambda j, y: (j + 1, y + x), + [0, 0])[1], + lambda: x) + return i + 1, cond_result + # pylint: enable=g-long-lambda + self._testWhileLoopAndGradientSerDes(body) + def testStrippedOpListDef(self): with self.test_session(): # Creates a graph. @@ -2660,5 +2736,92 @@ class ScopedGraphTest(test.TestCase): self.assertEqual(2.0, var_dict2["variable2:0"].eval()) +class _OwnsAVariableSimple(checkpointable.CheckpointableBase): + """A Checkpointable object which can be saved using a tf.train.Saver.""" + + def __init__(self): + self.non_dep_variable = variable_scope.get_variable( + name="non_dep_variable", initializer=6., use_resource=True) + + def _gather_saveables_for_checkpoint(self): + return {checkpointable.VARIABLE_VALUE_KEY: self.non_dep_variable} + + # The Saver sorts by name before parsing, so we need a name property. + @property + def name(self): + return self.non_dep_variable.name + + +class _MirroringSaveable( + saver_module.BaseSaverBuilder.ResourceVariableSaveable): + + def __init__(self, primary_variable, mirrored_variable): + self._primary_variable = primary_variable + self._mirrored_variable = mirrored_variable + super(_MirroringSaveable, self).__init__( + self._primary_variable, "", self._primary_variable.name) + + def restore(self, restored_tensors, restored_shapes): + """Restore the same value into both variables.""" + tensor, = restored_tensors + return control_flow_ops.group( + self._primary_variable.assign(tensor), + self._mirrored_variable.assign(tensor)) + + +class _OwnsMirroredVariables(checkpointable.CheckpointableBase): + """A Checkpointable object which returns a more complex SaveableObject.""" + + def __init__(self): + self.non_dep_variable = variable_scope.get_variable( + name="non_dep_variable", initializer=6., use_resource=True) + self.mirrored = variable_scope.get_variable( + name="mirrored", initializer=15., use_resource=True) + + def _gather_saveables_for_checkpoint(self): + saveable = _MirroringSaveable( + primary_variable=self.non_dep_variable, + mirrored_variable=self.mirrored) + return {checkpointable.VARIABLE_VALUE_KEY: saveable} + + # The Saver sorts by name before parsing, so we need a name property. + @property + def name(self): + return self.non_dep_variable.name + + +@test_util.with_c_api +class CheckpointableCompatibilityTests(test.TestCase): + + # TODO(allenl): Track down python3 reference cycles in these tests. + @test_util.run_in_graph_and_eager_modes() + def testNotSaveableButIsCheckpointable(self): + v = _OwnsAVariableSimple() + saver = saver_module.Saver(var_list=[v]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + self.evaluate(v.non_dep_variable.assign(42.)) + with self.test_session() as sess: + save_path = saver.save(sess, prefix) + self.evaluate(v.non_dep_variable.assign(43.)) + saver.restore(sess, save_path) + self.assertEqual(42., self.evaluate(v.non_dep_variable)) + + @test_util.run_in_graph_and_eager_modes() + def testMoreComplexSaveableReturned(self): + v = _OwnsMirroredVariables() + saver = saver_module.Saver(var_list=[v]) + test_dir = self.get_temp_dir() + prefix = os.path.join(test_dir, "ckpt") + self.evaluate(v.non_dep_variable.assign(42.)) + with self.test_session() as sess: + save_path = saver.save(sess, prefix) + self.evaluate(v.non_dep_variable.assign(43.)) + self.evaluate(v.mirrored.assign(44.)) + saver.restore(sess, save_path) + self.assertEqual(42., self.evaluate(v.non_dep_variable)) + self.assertEqual(42., self.evaluate(v.mirrored)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/training/slot_creator.py b/tensorflow/python/training/slot_creator.py index 18a5b89d300b0eafa812211d5287b15018e7d936..75ef3d5976aba9f0cbe849d9f6984646d71a29ef 100644 --- a/tensorflow/python/training/slot_creator.py +++ b/tensorflow/python/training/slot_creator.py @@ -48,11 +48,6 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables -def _is_resource(v): - """Returns true if v is something you get from a resource variable.""" - return isinstance(v, resource_variable_ops.ResourceVariable) - - def _create_slot_var(primary, val, scope, validate_shape, shape, dtype): """Helper function for creating a slot variable.""" @@ -65,7 +60,7 @@ def _create_slot_var(primary, val, scope, validate_shape, shape, dtype): shape = shape if callable(val) else None slot = variable_scope.get_variable( scope, initializer=val, trainable=False, - use_resource=_is_resource(primary), + use_resource=resource_variable_ops.is_resource_variable(primary), shape=shape, dtype=dtype, validate_shape=validate_shape) variable_scope.get_variable_scope().set_partitioner(current_partitioner) diff --git a/tensorflow/python/util/compat_internal.py b/tensorflow/python/util/compat_internal.py index d8b9319f668b85e227b9a0578b63fd46af0f2c13..1905c3e3832550906c601bd4545e72b5bd135e2c 100644 --- a/tensorflow/python/util/compat_internal.py +++ b/tensorflow/python/util/compat_internal.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.util.compat import as_str_any + def path_to_str(path): """Returns the file system path representation of a `PathLike` object, else as it is. diff --git a/tensorflow/python/util/decorator_utils.py b/tensorflow/python/util/decorator_utils.py index df259c7f7c29f9a4b674d3e980b33d6dcf323769..7b4363c0e40802779cf47c75c5a5e5a901da37e2 100644 --- a/tensorflow/python/util/decorator_utils.py +++ b/tensorflow/python/util/decorator_utils.py @@ -82,7 +82,7 @@ def add_notice_to_docstring( lines = _normalize_docstring(doc).splitlines() lines[0] += ' ' + suffix_str - notice = [''] + notice + [instructions] + notice = [''] + notice + ([instructions] if instructions else []) if len(lines) > 1: # Make sure that we keep our distance from the main body diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py index c4168f7b1ac80976a957e96c79c72fe3b288d622..a7cead5555df05a988418e8b08cd72999ba3c34e 100644 --- a/tensorflow/python/util/tf_inspect.py +++ b/tensorflow/python/util/tf_inspect.py @@ -134,6 +134,11 @@ def getmembers(object, predicate=None): # pylint: disable=redefined-builtin return _inspect.getmembers(object, predicate) +def getmodule(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.getmodule.""" + return _inspect.getmodule(object) + + def getmro(cls): """TFDecorator-aware replacement for inspect.getmro.""" return _inspect.getmro(cls) @@ -144,6 +149,11 @@ def getsource(object): # pylint: disable=redefined-builtin return _inspect.getsource(tf_decorator.unwrap(object)[1]) +def isbuiltin(object): # pylint: disable=redefined-builtin + """TFDecorator-aware replacement for inspect.isbuiltin.""" + return _inspect.isbuiltin(tf_decorator.unwrap(object)[1]) + + def isclass(object): # pylint: disable=redefined-builtin """TFDecorator-aware replacement for inspect.isclass.""" return _inspect.isclass(tf_decorator.unwrap(object)[1]) diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py index a9e8ffb30c3392251c2bf7076e02aafd2338696b..129408449ebb45ac3a322f163a13b705cbb31f0c 100644 --- a/tensorflow/python/util/tf_inspect_test.py +++ b/tensorflow/python/util/tf_inspect_test.py @@ -124,6 +124,17 @@ class TfInspectTest(test.TestCase): inspect.getmembers(TestDecoratedClass), tf_inspect.getmembers(TestDecoratedClass)) + def testGetModule(self): + self.assertEqual( + inspect.getmodule(TestDecoratedClass), + tf_inspect.getmodule(TestDecoratedClass)) + self.assertEqual( + inspect.getmodule(test_decorated_function), + tf_inspect.getmodule(test_decorated_function)) + self.assertEqual( + inspect.getmodule(test_undecorated_function), + tf_inspect.getmodule(test_undecorated_function)) + def testGetSource(self): expected = '''@test_decorator('decorator') def test_decorated_function_with_defaults(a, b=2, c='Hello'): @@ -133,6 +144,19 @@ def test_decorated_function_with_defaults(a, b=2, c='Hello'): self.assertEqual( expected, tf_inspect.getsource(test_decorated_function_with_defaults)) + def testIsBuiltin(self): + self.assertEqual( + tf_inspect.isbuiltin(TestDecoratedClass), + inspect.isbuiltin(TestDecoratedClass)) + self.assertEqual( + tf_inspect.isbuiltin(test_decorated_function), + inspect.isbuiltin(test_decorated_function)) + self.assertEqual( + tf_inspect.isbuiltin(test_undecorated_function), + inspect.isbuiltin(test_undecorated_function)) + self.assertEqual(tf_inspect.isbuiltin(range), inspect.isbuiltin(range)) + self.assertEqual(tf_inspect.isbuiltin(max), inspect.isbuiltin(max)) + def testIsClass(self): self.assertTrue(tf_inspect.isclass(TestDecoratedClass)) self.assertFalse(tf_inspect.isclass(test_decorated_function)) diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index b6abd42767f7b7048dce30d2b7a5b524513ff79c..61cf4ba7eac1f9482e3c1b179f35434a2a65d955 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -577,7 +577,7 @@ class ScopedFilterDescriptor { // A helper function to decide whether to enable the TENSOR_OP_MATH math type static bool TensorOpMathEnabled() { static bool is_enabled = [] { - bool is_disabled; + bool is_disabled = false; TF_CHECK_OK( tensorflow::ReadBoolFromEnvVar("TF_DISABLE_CUDNN_TENSOR_OP_MATH", /*default_val=*/false, &is_disabled)); @@ -586,6 +586,25 @@ static bool TensorOpMathEnabled() { return is_enabled; } +// A helper function to decide whether to use CUDNN_BATCHNORM_SPATIAL_PERSISTENT +// in batchnorm. This mode can be faster in some tasks because an optimized path +// may be selected for CUDNN_DATA_FLOAT and CUDNN_DATA_HALF data types, compute +// capability 6.0 or higher. The reason we set it to false by default is that +// this mode may use scaled atomic integer reduction that may cause a numerical +// overflow for certain input data range. +// TODO(yangzihao): Use autotune to choose between this mode and +// CUDNN_BATCHNORM_SPATIAL mode. +static bool BatchnormSpatialPersistentEnabled() { + static bool is_enabled = [] { + bool is_enabled = false; + TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar( + "TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT", + /*default_val=*/false, &is_enabled)); + return is_enabled; + }(); + return is_enabled; +} + // Turns a ConvolutionDescriptor structure into a cudnn convolution handle // within a scope. class ScopedConvolutionDescriptor { @@ -2773,6 +2792,11 @@ bool CudnnSupport::DoBatchNormalizationForwardImpl( ScopedTensorDescriptor scale_offset_descriptor{ parent_, scale_offset_desc, ToCudnnDataType(scale_data_type)}; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; +#if CUDNN_VERSION >= 7000 + if (BatchnormSpatialPersistentEnabled() && is_training) { + mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; + } +#endif float one = 1.0; float zero = 0.0; @@ -2874,6 +2898,11 @@ bool CudnnSupport::DoBatchNormalizationBackwardImpl( parent_, scale_offset_desc, static_cast(cudnn_scale_type)}; cudnnBatchNormMode_t mode = CUDNN_BATCHNORM_SPATIAL; +#if CUDNN_VERSION >= 7000 + if (BatchnormSpatialPersistentEnabled()) { + mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; + } +#endif float one = 1.0; float zero = 0.0; diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 748ce12cfe017a91823110f72414c75f37602ab7..818d67f7b5be1e8f2db66b24976a529b361a4990 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -498,6 +498,9 @@ def tf_gen_op_wrappers_cc(name, # is invalid to specify both "hidden" and "op_whitelist". # cc_linkopts: Optional linkopts to be added to tf_cc_binary that contains the # specified ops. +# gen_locally: if True, the genrule to generate the Python library will be run +# without sandboxing. This would help when the genrule depends on symlinks +# which may not be supported in the sandbox. def tf_gen_op_wrapper_py(name, out=None, hidden=None, @@ -508,7 +511,8 @@ def tf_gen_op_wrapper_py(name, generated_target_name=None, op_whitelist=[], cc_linkopts=[], - api_def_srcs=[]): + api_def_srcs=[], + gen_locally=False): if (hidden or hidden_file) and op_whitelist: fail('Cannot pass specify both hidden and op_whitelist.') @@ -563,6 +567,7 @@ def tf_gen_op_wrapper_py(name, outs=[out], srcs=api_def_srcs + [hidden_file], tools=[tool_name] + tf_binary_additional_srcs(), + local = (1 if gen_locally else 0), cmd=("$(location " + tool_name + ") " + api_def_args_str + " @$(location " + hidden_file + ") " + ("1" if require_shape_functions else "0") + " > $@")) @@ -572,6 +577,7 @@ def tf_gen_op_wrapper_py(name, outs=[out], srcs=api_def_srcs, tools=[tool_name] + tf_binary_additional_srcs(), + local = (1 if gen_locally else 0), cmd=("$(location " + tool_name + ") " + api_def_args_str + " " + op_list_arg + " " + ("1" if require_shape_functions else "0") + " " + @@ -612,7 +618,7 @@ def tf_cc_test(name, srcs=srcs + tf_binary_additional_srcs(), copts=tf_copts() + extra_copts, linkopts=select({ - "//tensorflow:android": [ + clean_dep("//tensorflow:android"): [ "-pie", ], clean_dep("//tensorflow:windows"): [], diff --git a/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt b/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt index 75361803a3991f380d6be2485cfd3d05fd1572e1..cdaeb55e30865e082054085f47d6a071ebf3affd 100644 --- a/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-graph.pbtxt @@ -130,6 +130,10 @@ tf_class { name: "prevent_fetching" argspec: "args=[\'self\', \'op\'], varargs=None, keywords=None, defaults=None" } + member_method { + name: "switch_to_thread_local" + argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + } member_method { name: "unique_name" argspec: "args=[\'self\', \'name\', \'mark_as_used\'], varargs=None, keywords=None, defaults=[\'True\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt index bc7cf7267f7d23121402e63903f01ddc6caa2e04..5a02bb2175e2d6ad71722799143090f2735c1a37 100644 --- a/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.-variable.pbtxt @@ -1,6 +1,7 @@ path: "tensorflow.Variable" tf_class { is_instance: "" + is_instance: "" is_instance: "" member { name: "SaveSliceInfo" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt index 874a73f661d782ff5637b751f104fd2209734599..be9ba4ce85bd5b9905a39e3f45873c534594e15f 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt index 8da2a2b6867a3f9a3d82fcdb76ac4a62d5cee825..91fca67b6b5b1187b61f398a152793362c0c6e30 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt index efc441ae2f2a00f663c11f84c1155bece0c8e08a..cd4f72fcf839fa89f25c7ed115ee6c61294283c3 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt index 20ce87987060d9013bd071d6fc9f1f4f33467121..303fd74a64d0c7f5a0292a4eaabec63455c29381 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt index 73211aaf8ba5f925982afe3d17c4b8f009250cb8..c97ea7969eff3e6952a604e72ce140b49d304461 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt index 27a159639d2098aace2e69718d9ac4e38a29fdc3..4b5b5bf0e3599a81e2e853ae8ba34ef12cc63097 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt index 76f527f796e95f342eb144ae3de87ff234338021..42a0d595216ad28363727b9d7c066fc37fddd02c 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt @@ -44,7 +44,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt index c45318b98a034255d32c326179813de14cf1d4c8..2de52d6c57cc70b562c3c10b7f23cd15b63e25f8 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt index 04a2aa080d0704a8b7ec98f8eafda4bd1944e567..e552f33720bb939b8a98d34ef3de78bda7db976c 100644 --- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt @@ -45,7 +45,7 @@ tf_class { } member_method { name: "predict" - argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], " } member_method { name: "train" diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt index baedf596e8fbce921ed7e0570542b8a11655dba4..bda1c2bf85977e69b0969bc8b6056710d88ca910 100644 --- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.image.pbtxt @@ -100,6 +100,10 @@ tf_module { name: "hsv_to_rgb" argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "is_jpeg" + argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "non_max_suppression" argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt index 2bf584fa2936990b467b2da9c48620a31814691a..241db8956a5bc01a058048d3b21b2e1cbe56c92f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt @@ -1,9 +1,8 @@ path: "tensorflow.keras.Model" tf_class { is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -38,6 +37,10 @@ tf_class { name: "input_spec" mtype: "" } + member { + name: "layers" + mtype: "" + } member { name: "losses" mtype: "" @@ -108,11 +111,11 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'inputs\', \'outputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_loss" - argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_update" @@ -120,7 +123,7 @@ tf_class { } member_method { name: "add_variable" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " } member_method { name: "add_weight" @@ -136,11 +139,11 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt index 0a6096813155d59eb1c7920f2bcd250ed9730982..9673a508d610778029013b9388ddafd34713f301 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt @@ -1,10 +1,9 @@ path: "tensorflow.keras.Sequential" tf_class { - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -39,6 +38,10 @@ tf_class { name: "input_spec" mtype: "" } + member { + name: "layers" + mtype: "" + } member { name: "losses" mtype: "" @@ -125,7 +128,7 @@ tf_class { } member_method { name: "add_loss" - argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_update" @@ -133,7 +136,7 @@ tf_class { } member_method { name: "add_variable" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " } member_method { name: "add_weight" @@ -149,7 +152,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" } member_method { name: "compile" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt index ea4d5143540611f0585b67910cb319454b8560dc..454823fd23e72c6aa6bf6aa608707fa3b893b986 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-base-logger.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'self\', \'stateful_metrics\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "on_batch_begin" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt index 0e6901f28affdfc73092c2b9f3af07d17db61a9f..543de0ad48b86502fc83374e5e6d82822485f331 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.callbacks.-progbar-logger.pbtxt @@ -5,7 +5,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'count_mode\'], varargs=None, keywords=None, defaults=[\'samples\'], " + argspec: "args=[\'self\', \'count_mode\', \'stateful_metrics\'], varargs=None, keywords=None, defaults=[\'samples\', \'None\'], " } member_method { name: "on_batch_begin" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt index f4ab075959906cdf350ec5d49dc86f928b7eb7ae..041acf29ff76d7271913204c817cc8c3d47429a5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activation.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Activation" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt index eb558cddafc3972127786353072767f0d53bf174..48143b2cd66b4cec6f8833be71f27645be9dc898 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-activity-regularization.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ActivityRegularization" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt index 770a107b664d7ab0a8aedf292a34d4258a201859..11f78fed9733d8a17072f73a78799bcae823d469 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-add.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Add" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt index 0ce42b706ec20a8ea1cc83ec95cb64d9be2e5710..84eb8256325974619a9252d11da139032cdddfa3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-alpha-dropout.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.AlphaDropout" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt index d6c98fa225ce924bc8e20f8531516eaed4d32ffb..ab377a248f093530396cce7bc5baacfeba237e2b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling1-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt index 754fd310c6d8ddb994db0590342b29f8cb7abd71..c2edd79f5263b52f0a1ab9df9cb29517b33da7de 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling2-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt index 9b62880c7931d151fb98cc1dc3149dcbd4dd103d..f3f37eed9946132a91c8c872411f164df7d1691d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average-pooling3-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt index b371ad148cee16dd243869d929e0c1c002794682..31d1d1c049c3009b37d80839832ddf44ca1cd6d6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-average.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Average" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt index 3e2aba55fd63326bb0e232fdce06f32884db7a0a..6582e1b18eb1982719cbac6b679ec830ce5938b3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool1-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt index fb37308cce0124538648c3837e1e802794d7f1ae..12f66095d2de014e4e7dfc02f5a7a2341db428f5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool2-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt index 813470ffc7c87727eb0b958e54806f530399806a..3a45fa180ee90a921fe4bbde0924cb8364ddb9e5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-avg-pool3-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt index e251ac18e511b58a49816126d9941b98e4f91088..a0f272c1788a3fd197bae6f5583e009f97dd3c56 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-batch-normalization.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.BatchNormalization" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt index db26c3e568da09d1523003ab538d565c6c2e1464..9c7d3154ad43cb75e310e9c8dd3a2f7a46153fce 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-bidirectional.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Bidirectional" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -73,6 +73,10 @@ tf_class { name: "scope_name" mtype: "" } + member { + name: "trainable" + mtype: "" + } member { name: "trainable_variables" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt index ff08def0a08e5201bc01d61be3f2d66d712c384b..949b225e545132cf2464fd01b3607e0ef2c44b7f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-concatenate.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Concatenate" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt index 6db22ca0320519fd9c101456c9c9c0e26a9a11e0..a736c84a102e4780ca04ff1fad92f9310c841814 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt index 577f206e3510a9995d5d383ac440b4f68ea39fe5..95f9afed28961e88c8329117ed6714b91e72b10d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv1-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt index 72924c32b43e5edb39938cc0cd909cffefa61be1..38ba15400a49088fd5337a43f7d6ed3b4067a9d3 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d-transpose.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt index 16be08d9b2bae8fe1faecf34c4d87ac9b9baf142..bc84e2a97e549bed7a6a4e255b3c3d3fe6cd5250 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv2-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt index 11e05f884d781166616a9c9a61dacbc8fdae6ae3..0802578c227b2d5ed56e5180075f604680d94397 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d-transpose.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt index 72b72d6b3b1e410dda0b0a529449f0135203fc1b..8ad4646c749b22d3c8bc79ada9d3d875bdb34201 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-conv3-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt index ee93247f63ed700dc6058041bd0ea4ff5c879078..110e267b75c59793af194982a404a318d81ded7c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution1-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt index e5023287e5f38553f3553a37b5a908790072b5c7..24cfc83af61f3742229ee3a7bfcb3b399db53291 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt index ba38cb7121c9d312e7ba9d7147bdc67673d1ad2e..c56e89187f26fee8447889465244398c600c6e18 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution2-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt index 58724a1e1661609ef3c000c7ca1dfe9b3235acff..3674f2746caf04f56618bbd2824c9ae54ade21b8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt index 98d52c430c659d0fc3e9299f7bede9190dad2fcf..5a8f9d770280b5ca1f9c91d46042dcca061f31c9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-convolution3-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt index 33b6ebe1af731f66f88a9493502f69049ab34b42..caa748be8150be045ccaede0c30f7d6eee66c30e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping1-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Cropping1D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt index 4b241ebb0f68c270a9448b02138d44f82211f418..97bd4a265a9b0139207e795252679d915ba1bde0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping2-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Cropping2D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt index 1856a9ee21347ed6ca3dd592517eb644e205a5b7..20c43eeed1c06d5970684d792c193913eed8b5d4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-cropping3-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Cropping3D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt index a8c37af31f649d28ca2ab7614178f2dee58c13fc..256f0e4bdf40a4bf10f65f873d50a5d328091740 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dense.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Dense" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt index 07d3f023e54105c606b198c05750ffa78ee5d0c8..d1e53f900c2dd6630499486083d33e4d193b30fd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dot.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Dot" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt index e2e21b5f123f63fa38cb0e344be9a12fc091f20b..b010ff6805f3d032ba8841fecb8a98aabd604a88 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-dropout.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Dropout" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt index 92b9760d53e35d3e5066a730bb5cbda45492cc64..fffd3854bbbf486c67d32f259e605b14b2c42ede 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-e-l-u.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ELU" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt index 83c528b40117222ac2b3e85ad338459948d0aa8c..1155fe03fc53d5ea046fd3c2d04353116ab98273 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-embedding.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Embedding" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt index 73609752886c8c57a78f6bc02cc46d2c7ff6e996..5e4bebb15b54742ce1d6a4bf31482046fe6f5be5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-flatten.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Flatten" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt index b329f1c46bb07ab7684dec6aaf45a20b98c27ed9..cb9bb3d82151eb6f9f4e4c90133519a1d6fbde63 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.GRUCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt index c741d4d6e6cf8da9712e68f86abe64e2828823da..9a36e806498f3fc8b16c1b2587244bdad515ad5f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-g-r-u.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GRU" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt index 57596badf1881950270fa6d3c074afb65daaa8eb..eb32238e151a3665f852df3b693b716e6bf04fa2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-dropout.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.GaussianDropout" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt index 3829353cc3c195a750ad862707c5c8563e203fba..37fc8e29aefda3cc1c47b4934d7d421b3f438eb5 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-gaussian-noise.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.GaussianNoise" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt index e53e78a977b32eaf2e31867044aedd39ab2dd34f..490816458b08a0663d07d62a2d9dfad7f1c95dd8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAveragePooling1D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt index 48fcd1044e06b2fe61aadb6c3675ce82197ff003..ab49f67f336d256e46282c5e5a18e498fa56b9c7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAveragePooling2D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt index 66c06ed47289eb2d83d97778a7b13dab821722d2..3d7cb3ba491c91341b37270a69e758f08161ec53 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAveragePooling3D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt index 4f2420f74ab3069952e4a44bf61e5e12b3e80ea3..c99ddab4f3277d43e663b9af6e8ce7dfe6307f05 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAvgPool1D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt index 7912a6d933b851521358e0246d04688da410b909..290d2eaebe8ef9f5bab2eebe224e261561c9a86d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAvgPool2D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt index d5b2d2c274ad97071497045271c0a595f8e0e062..cf63069641a61c029e6851be35c2fa8e8f433b22 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalAvgPool3D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt index d88ff17eb6df7bbba7d3af4344fc8ddc367ae44c..2dadc67c0939df6a806bd4b1a0a91b81365aea3c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool1-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPool1D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt index c8cc5a0ddfdd54cbb47de922591a9842abf63396..1a1a1dcf64b726092f93c00d46c009d0b49b7baf 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool2-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPool2D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt index 7956c5a340d963cfd5976e8af56da222848a164a..44898e23ad8faff5c806de20f18cf62bfd6708f2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pool3-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPool3D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt index 0a7e16413dfbd80d448eb1bad5771915475d96b2..941d867d24c42684d78b61cf06017ecde7c6eefc 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPooling1D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt index 6c8a58a996f5313ea48e395e7e443a7c21f198ee..9a5a6325f89b2b4253622e918c17e9a86701b8e7 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPooling2D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt index 7678ce8aab63fcfa76c0ac61346a723c1dfe1ee7..7a0c1932f6239897bdbde24f3a7269027cb008ca 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.GlobalMaxPooling3D" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt index d46fd41a3f33002a9bbe755851278c9729ccd1d1..f679c1d00692e9e31eba4752e4b9ed7cb2494d8e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-input-layer.pbtxt @@ -1,8 +1,7 @@ path: "tensorflow.keras.layers.InputLayer" tf_class { - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt index 3b171b137af699c9608494a17c5651b439fe4545..ad1e7f2cad74a72e5489bfeb857a90511ececb03 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.LSTMCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt index 29d9cf78ab5ed3bdd1a488359b59cf7171e7e051..6dad4b48979b11afbc50e33e4b3766429b1a9541 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-l-s-t-m.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.LSTM" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt index ca0144929942f7024a4e8bac5552bf0547ceb56d..fa45d8c9028639c4c1f9dbf930cda107ea8819d8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-lambda.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Lambda" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt index c52ad727545c0bf4f199714d71180eac3f1bf62a..023d6c0d69790ab2e6203d3331f63fdb44c85b2e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-layer.pbtxt @@ -1,6 +1,6 @@ path: "tensorflow.keras.layers.Layer" tf_class { - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt index 8134fb738683b79764662d9ea7f721fe04751162..e429fced77bd163c381466df4ba18e2672e31b0d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-leaky-re-l-u.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.LeakyReLU" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt index c5d452300947d7f74e7458e2a04bfdfabb1c1da2..462568124f72b334a7a83bdbd0adc4ac1b14218e 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected1-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.LocallyConnected1D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt index bcbed9241b525a953c8b499197facaefebe8cc44..11bf6a2b426d7f76ab67b0450048d56e02211702 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-locally-connected2-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.LocallyConnected2D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt index 244e79b4ffe60ddd6aa56d2780d80dfd66c494a9..a9324488911909395929d8e08c7429fab0bb5058 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-masking.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Masking" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt index 56cbf5df785ef0e2614ea7e9e6cfe1335e148eec..6ff2adddac4bf679242856f7ed21a6ed05f949b8 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool1-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt index 33c2d30e86f9cdc3fb9f4f498bfc2c94497fe2dd..2957673d4d461b9111a2a40bac0f2489fa798ff0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool2-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt index 94f91059b7a1e291c38fe0045accc6c03f226603..2191c10b7399feb8cf488cb71e8047b746daf523 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pool3-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt index 247230a6d68b8ea93a30a2f5846d8baaa78cb13e..af750ac1b61e23509902aae3ec9c91bbb063509d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling1-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt index 8d61b67e7ce9564d31b0bd904a58540d19c89172..9046061510828af39e71a2eee6e29cc8ca7c92d1 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling2-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt index ad2e30802006e934730e5c75247e958329f7121c..a40666807be0b50e53809f254dc6cbc16a403209 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-max-pooling3-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt index ff0db15f190675d533c50c277eb1cb60e0b95e55..65378cef42215ddbc471e4014424fde3010d23b6 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-maximum.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Maximum" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt index 1d3f33f04516345ee32f16befe0d7200d2cdad00..b037559e02a8200f4251f2d8604bbd7c2595cb15 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-multiply.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.Multiply" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt index c86bc49b22a8cc3e004a77f4a21594aacb2c665a..b3a7f47fa59af9dccdc74af3b138e0ed66234bae 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-p-re-l-u.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.PReLU" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt index 2043e1a1263f0f0745b7c6446cc670fd6b0f0000..b2f22f7da3f5f33e7c750d22690c5732c4fe7643 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-permute.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Permute" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt index ad539a7c4c5362500baef0a9c89d054762bbb47d..792eacf90dc96d6c4f85d014713a35eeb623c4ec 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-r-n-n.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.RNN" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt index 4b0e98520a0dd86c085fa7345af445e1ae253d3b..5b79a021caa0a52776cf87207a434e5044caa8a9 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-repeat-vector.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.RepeatVector" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt index 34bc71af8a26ff6e4d7c81a3877751df5209906f..99c64505ee6e92f09d9ee25372ed173583988885 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-reshape.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Reshape" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt index dd67b76523cc50409516e29f963f59d039455bfd..d5873ccf7659d162c02e5855df5919046c2e3554 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv1-d.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt index 5d898fb2bd86b39cb8fab755382bb96cce231fa6..76b4c10a46161d598b2eaa9c871e665a47149a4d 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-conv2-d.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt index bf62c095e7cc3fbeac95919a0f9fdc545efd3d25..40cd87de5fa9c05cfb4a2978d847e573e122bccf 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution1-d.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt index c758d87993b3acba88a13c7bc9eaeee929a22652..c44c0da1485821ec6b69f20efa20fc782bd9edbc 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-separable-convolution2-d.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt index 6e3cde3e3eaba4f9985411d66a220f7cdd4ee7ad..bd70c31c38032da5ca6953a208d58a4cd5d04d3f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.SimpleRNNCell" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt index 6fafc77b947d0df11755e3136ed2e7a14c148081..de717976cf285827ccedbbcc91f88d5d95df58fa 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-simple-r-n-n.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.SimpleRNN" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt index ee4b2fa39ed34a544ee800e9370e4f34c4a17041..a93b7b8f6e116ce955b816a849ff68926f7adec2 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-softmax.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Softmax" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt index e4727072e375b9fc4dc99a1536eaaf3df5415369..4dc24b195e3ae0fe896e9f844968b2c0f174ca77 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt index c5ff7043115ccdd3bc4a1147790b20feda410f65..a3bb1cc414e9fdcdda9f2b4588ab44005b0c51fd 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt index 476a7f362cf88e234e964f6f6645ee4ed0cbaff8..f9a78106fa406992d4672c6411a65d9e33f709a4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt @@ -3,7 +3,7 @@ tf_class { is_instance: "" is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt index 90c37bd98650db42abceb9508c7dc7e564cbee68..5aa21f402228794ae9d1053e0958fd82e65c8a51 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.StackedRNNCells" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -122,7 +122,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'states\'], varargs=None, keywords=kwargs, defaults=None" + argspec: "args=[\'self\', \'inputs\', \'states\', \'constants\'], varargs=None, keywords=kwargs, defaults=[\'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt index ef31c5443efa0c0e5a7a2e0a422d2a9c9c49baaf..88e8a465725998997d2637cc833d87024aaf8a21 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ThresholdedReLU" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt index 40aa782a02b6f2ce71860f0df5c3e61ead68e337..f2a7673998d597370b4565c52b26399b14038f8f 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-time-distributed.pbtxt @@ -2,7 +2,7 @@ path: "tensorflow.keras.layers.TimeDistributed" tf_class { is_instance: "" is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -69,6 +69,10 @@ tf_class { name: "scope_name" mtype: "" } + member { + name: "trainable" + mtype: "" + } member { name: "trainable_variables" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt index a81b83be49e0073f242efc6890e419b4fe172ab2..4db82ddfa931b7dc4c0588184261f27239688a56 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling1-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.UpSampling1D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt index 5403279d45ec7b93bae7907b891c659a043e96d0..61e65ad56df0a380417ec8ef70af3d200c90f72c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling2-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.UpSampling2D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt index 96c337caf28d43fabd0b90df016f4e8ab0c408db..3d9402db4e2a52c8b746b256aea00392f26f17d1 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-up-sampling3-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.UpSampling3D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt index 27a54382a47dffd17810ebdcb45cb838c1442635..0223799ed4cce948fdc7b90a3957a72f16d9eed4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-wrapper.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.Wrapper" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -68,6 +68,10 @@ tf_class { name: "scope_name" mtype: "" } + member { + name: "trainable" + mtype: "" + } member { name: "trainable_variables" mtype: "" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt index b81a4b1c50b22f13eacb521cfc8bc288bd40c81f..2e4429833a91bede84e9a19b9c06c7a9146edee4 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding1-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ZeroPadding1D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt index 1a26f2f3c9bbaa2aa567e76e1aafe14805ecff38..26cf7b9e49b124108a428801350661f610527e99 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding2-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ZeroPadding2D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt index 310277fe67433fd870ae3d907984f402576925b2..64d35d944795378d3f49060a6b50aac387a1c8a0 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.layers.-zero-padding3-d.pbtxt @@ -1,7 +1,7 @@ path: "tensorflow.keras.layers.ZeroPadding3D" tf_class { is_instance: "" - is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { diff --git a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt index de285c1aab197ea5cae9c94048a5131f8463ebde..42729e4237685638d38301cece6e93383ddfffba 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.metrics.pbtxt @@ -22,7 +22,7 @@ tf_module { } member_method { name: "deserialize" - argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], " } member_method { name: "get" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt index 0b816b58631d12471c2e9db96fc5395796d96ddf..18be9c97014b526118b44544db08c8c86e3dbc2c 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt @@ -1,9 +1,8 @@ path: "tensorflow.keras.models.Model" tf_class { is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -38,6 +37,10 @@ tf_class { name: "input_spec" mtype: "" } + member { + name: "layers" + mtype: "" + } member { name: "losses" mtype: "" @@ -108,11 +111,11 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'inputs\', \'outputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_loss" - argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_update" @@ -120,7 +123,7 @@ tf_class { } member_method { name: "add_variable" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " } member_method { name: "add_weight" @@ -136,11 +139,11 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " } member_method { name: "compile" - argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], " } member_method { name: "compute_mask" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt index 7c1bfcb22558ec3a64c63ebbf0466f9114ef68ee..b9346329225280b6f412d7cc892a13a9b8b33cff 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt @@ -1,10 +1,9 @@ path: "tensorflow.keras.models.Sequential" tf_class { - is_instance: "" + is_instance: "" is_instance: "" - is_instance: "" - is_instance: "" - is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" is_instance: "" member { @@ -39,6 +38,10 @@ tf_class { name: "input_spec" mtype: "" } + member { + name: "layers" + mtype: "" + } member { name: "losses" mtype: "" @@ -125,7 +128,7 @@ tf_class { } member_method { name: "add_loss" - argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None" } member_method { name: "add_update" @@ -133,7 +136,7 @@ tf_class { } member_method { name: "add_variable" - argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], " } member_method { name: "add_weight" @@ -149,7 +152,7 @@ tf_class { } member_method { name: "call" - argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], " + argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None" } member_method { name: "compile" diff --git a/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt index 3adc6b6faa6f62330f9ac3d621f29adfc380a09d..16e1cbe650e1662f8694fd7137ad20a48a90675b 100644 --- a/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.keras.utils.-progbar.pbtxt @@ -4,7 +4,7 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'target\', \'width\', \'verbose\', \'interval\'], varargs=None, keywords=None, defaults=[\'30\', \'1\', \'0.05\'], " + argspec: "args=[\'self\', \'target\', \'width\', \'verbose\', \'interval\', \'stateful_metrics\'], varargs=None, keywords=None, defaults=[\'30\', \'1\', \'0.05\', \'None\'], " } member_method { name: "add" @@ -12,6 +12,6 @@ tf_class { } member_method { name: "update" - argspec: "args=[\'self\', \'current\', \'values\', \'force\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], " + argspec: "args=[\'self\', \'current\', \'values\'], varargs=None, keywords=None, defaults=[\'None\'], " } } diff --git a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt index 59134f84891ad5518dcb5331ce04475482c8b59e..df74c32e1f10cc7540ef105adef6be681e93d089 100644 --- a/tensorflow/tools/api/golden/tensorflow.layers.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.layers.pbtxt @@ -76,10 +76,6 @@ tf_module { name: "SeparableConv2D" mtype: "" } - member_method { - name: "Input" - argspec: "args=[\'shape\', \'batch_size\', \'name\', \'dtype\', \'sparse\', \'tensor\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"\", \'False\', \'None\'], " - } member_method { name: "average_pooling1d" argspec: "args=[\'inputs\', \'pool_size\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'valid\', \'channels_last\', \'None\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt index 066c4513ff5185b50bdf193f579e71e505dbd3b6..f8d08f1d39a8bfa7d78be106e59d88de75a57823 100644 --- a/tensorflow/tools/api/golden/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.pbtxt @@ -1988,6 +1988,10 @@ tf_module { name: "tile" argspec: "args=[\'input\', \'multiples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "timestamp" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "to_bfloat16" argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'ToBFloat16\'], " diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt index 863beaea4cf05a67e572c97b556bc1eb598d9ced..c02e54adfbd9f33e661453767b517a5f0de90d57 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adadelta-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.AdadeltaOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt index 0a7aa9b6bc14c95e74ab05a3aeb71b770a918f60..2b619908fc6aea3f4b8e6a57d0dcf85a9854d466 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-d-a-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.AdagradDAOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt index 83724fea55d005e9476801feb1bf58cb004aa141..2005cf4677c06cf1f8b4207a444690fdd0c2306e 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adagrad-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.AdagradOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt index e285b27a0531e00d27941fe451570a5056995c17..0a2bae1d9021b20707e03ae5786e71f388266c14 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-adam-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.AdamOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt index fc28577d6ed1328ae85970cf22cc458b7cf54344..847f9ad75998f1bdda8858650091c70fd0b4015b 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-ftrl-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.FtrlOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt index bf3c1d81f877e3a8a7e24d5455e9c5bf6a41f764..13a58e0608ed269415ba78d84a03f1bae128e80c 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-gradient-descent-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.GradientDescentOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt index a640c8d2c6366951cbba6a15d2000d9369cbbdbf..bfbc2357a346c7bfef0242a735ab14c5f4005b22 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-momentum-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.MomentumOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt index 6b33c236a35f09422a42a17b3ffddf5ba7b1595f..437efa0a2bd04c308db6186e714a5d8785541fa5 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-optimizer.pbtxt @@ -1,6 +1,8 @@ path: "tensorflow.train.Optimizer" tf_class { is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt index d23fcaed7b4cee397dcf9c51eb3b521e5461c9e5..72f224605f67e72dd78699b5f1a703cc3edd566b 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-proximal-adagrad-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.ProximalAdagradOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt index b6c03e71d9ffb50bd6377b489fcc444453bd9752..316275b1fb1abd384e193994e35115a1c463f07d 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-proximal-gradient-descent-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.ProximalGradientDescentOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt index 4a82db11cb8d85bd0c44135ecaf507c62fae41a1..af50a1986100d830f0809a3f4a0f01faa8821b3b 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-r-m-s-prop-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.RMSPropOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt index e9131bf544f2e7f08928f46d2be06a00259690be..6edc516c9392fa14f23ffc2a6481ec21216f06cf 100644 --- a/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt +++ b/tensorflow/tools/api/golden/tensorflow.train.-sync-replicas-optimizer.pbtxt @@ -2,6 +2,8 @@ path: "tensorflow.train.SyncReplicasOptimizer" tf_class { is_instance: "" is_instance: "" + is_instance: "" + is_instance: "" is_instance: "" member { name: "GATE_GRAPH" diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index c1e09cc531ed8e8995e3e73b87e96b72fba6c038..2a784973e1098bb1f67eb1b002b7b006f69670ff 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -165,7 +165,7 @@ class ApiCompatibilityTest(test.TestCase): logging.error('%d differences found between API and golden.', diff_count) messages = verbose_diffs if verbose else diffs for i in range(diff_count): - logging.error('Issue %d\t: %s', i + 1, messages[i]) + print('Issue %d\t: %s' % (i + 1, messages[i]), file=sys.stderr) if update_goldens: # Write files if requested. diff --git a/tensorflow/tools/ci_build/builds/test_tutorials.sh b/tensorflow/tools/ci_build/builds/test_tutorials.sh index 67e5af556405a5c659000a07a79a6bd9a1d1e542..db335f14ca4f88ade7a540ffab7ed9de67f1248e 100755 --- a/tensorflow/tools/ci_build/builds/test_tutorials.sh +++ b/tensorflow/tools/ci_build/builds/test_tutorials.sh @@ -277,17 +277,6 @@ test_ptb_word_lm() { fi } - -# ----------------------------------------------------------- -# translate_test -test_translate_test() { - LOG_FILE=$1 - - run_in_directory "${TEST_DIR}" "${LOG_FILE}" \ - "${TF_MODELS_DIR}/tutorials/rnn/translate/translate.py" --self_test=True -} - - # Run the tutorial tests test_runner "tutorial test-on-install" \ "${TUT_TESTS}" "${TF_BUILD_TUT_TEST_BLACKLIST}" "${LOGS_DIR}" diff --git a/tensorflow/tools/ci_build/builds/with_the_same_user b/tensorflow/tools/ci_build/builds/with_the_same_user index 5817716c8dec37dfdfd50defb4b20b1deafced70..d4bf546d401d058bd205a70c147615c8efc4f4ba 100755 --- a/tensorflow/tools/ci_build/builds/with_the_same_user +++ b/tensorflow/tools/ci_build/builds/with_the_same_user @@ -36,8 +36,13 @@ else rm /this_is_writable_file_system fi +if [ -n "${CI_BUILD_USER_FORCE_BADNAME}" ]; then + ADDUSER_OPTS="--force-badname" +fi + getent group "${CI_BUILD_GID}" || addgroup --gid "${CI_BUILD_GID}" "${CI_BUILD_GROUP}" -getent passwd "${CI_BUILD_UID}" || adduser --gid "${CI_BUILD_GID}" --uid "${CI_BUILD_UID}" \ +getent passwd "${CI_BUILD_UID}" || adduser ${ADDUSER_OPTS} \ + --gid "${CI_BUILD_GID}" --uid "${CI_BUILD_UID}" \ --gecos "${CI_BUILD_USER} (generated by with_the_same_user script)" \ --disabled-password --home "${CI_BUILD_HOME}" --quiet "${CI_BUILD_USER}" usermod -a -G sudo "${CI_BUILD_USER}" diff --git a/tensorflow/tools/ci_build/install/install_bazel.sh b/tensorflow/tools/ci_build/install/install_bazel.sh index 1df6a84d7c6f86abfb965063625ac43a3f1a57fb..3e27a94cf2bf3110ac181d6ef5a57366be17255f 100755 --- a/tensorflow/tools/ci_build/install/install_bazel.sh +++ b/tensorflow/tools/ci_build/install/install_bazel.sh @@ -15,7 +15,7 @@ # ============================================================================== # Select bazel version. -BAZEL_VERSION="0.10.0" +BAZEL_VERSION="0.11.0" set +e local_bazel_ver=$(bazel version 2>&1 | grep -i label | awk '{print $3}') diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel index d16761c3675942838fd2be0ea6e0b7463a3bf249..22c73c3fe13f2cb763295fa25b43e2f82c0e8962 100644 --- a/tensorflow/tools/docker/Dockerfile.devel +++ b/tensorflow/tools/docker/Dockerfile.devel @@ -57,7 +57,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ >>/etc/bazel.bazelrc # Install the most recent bazel release. -ENV BAZEL_VERSION 0.8.0 +ENV BAZEL_VERSION 0.11.0 WORKDIR / RUN mkdir /bazel && \ cd /bazel && \ diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu index 4ef37881bc91aaa58bab031c69b4a96c2a9d8ec1..69ba340f9201266fd2c2f86571e83f6acdcda950 100644 --- a/tensorflow/tools/docker/Dockerfile.devel-gpu +++ b/tensorflow/tools/docker/Dockerfile.devel-gpu @@ -66,7 +66,7 @@ RUN echo "startup --batch" >>/etc/bazel.bazelrc RUN echo "build --spawn_strategy=standalone --genrule_strategy=standalone" \ >>/etc/bazel.bazelrc # Install the most recent bazel release. -ENV BAZEL_VERSION 0.8.0 +ENV BAZEL_VERSION 0.11.0 WORKDIR / RUN mkdir /bazel && \ cd /bazel && \ diff --git a/tensorflow/tools/docs/parser.py b/tensorflow/tools/docs/parser.py index 3db164c2b5b78dbcb3c408ce89c067d33c2a2af4..e758229535e7b10994a39cbafb37e116fd2a465c 100644 --- a/tensorflow/tools/docs/parser.py +++ b/tensorflow/tools/docs/parser.py @@ -111,8 +111,8 @@ SYMBOL_REFERENCE_RE = re.compile( r""" # Start with a literal "@{". @\{ - # Group at least 1 symbol: not "}" or "\n". - ([^}\n]+) + # Group at least 1 symbol, not "}". + ([^}]+) # Followed by a closing "}" \} """, diff --git a/tensorflow/tools/docs/parser_test.py b/tensorflow/tools/docs/parser_test.py index 8a0e9af5216c881326449b3e85b94c0be331fa37..fca5436ca5fadd1fb5da07d7523bb51c871164b5 100644 --- a/tensorflow/tools/docs/parser_test.py +++ b/tensorflow/tools/docs/parser_test.py @@ -76,8 +76,9 @@ class ParserTest(googletest.TestCase): pass string = ( - 'A @{tf.reference}, another @{tf.reference}, a member ' - '@{tf.reference.foo}, and a @{tf.third$link `text` with `code` in it}.') + 'A @{tf.reference}, another @{tf.reference$with\nnewline}, a member ' + '@{tf.reference.foo}, and a @{tf.third$link `text` with `code` in ' + 'it}.') duplicate_of = {'tf.third': 'tf.fourth'} index = {'tf.reference': HasOneMember, 'tf.reference.foo': HasOneMember.foo, @@ -93,7 +94,7 @@ class ParserTest(googletest.TestCase): self.assertEqual('A ' 'tf.reference, ' 'another ' - 'tf.reference, ' + 'with\nnewline, ' 'a member ' 'tf.reference.foo, ' 'and a link ' diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD index 8601b3d0f19e49fe1308f2d022ee13572351581e..4fe4fc3b137ddf453c9194424a0c4dc31c5a12c3 100644 --- a/tensorflow/tools/graph_transforms/BUILD +++ b/tensorflow/tools/graph_transforms/BUILD @@ -103,6 +103,7 @@ cc_library( "quantize_nodes.cc", "quantize_weights.cc", "remove_attribute.cc", + "remove_control_dependencies.cc", "remove_device.cc", "remove_ema.cc", "remove_nodes.cc", @@ -133,8 +134,8 @@ cc_library( "//tensorflow/core:tensorflow", "//tensorflow/contrib/rnn:gru_ops_op_lib", "//tensorflow/contrib/rnn:lstm_ops_op_lib", + "//tensorflow/core/kernels:quantization_utils", ] + if_not_windows([ - "//tensorflow/core/kernels:quantized_ops", "//tensorflow/core/kernels:remote_fused_graph_rewriter_transform", "//tensorflow/core/kernels/hexagon:hexagon_rewriter_transform", ]), diff --git a/tensorflow/tools/graph_transforms/README.md b/tensorflow/tools/graph_transforms/README.md index 345d9eadb858cadebe03ecb3297aea52ba54bd37..67badb4869029b684cae05130d3c1e190dfb40d2 100644 --- a/tensorflow/tools/graph_transforms/README.md +++ b/tensorflow/tools/graph_transforms/README.md @@ -639,6 +639,13 @@ specified devices may not be available. In order to work with graphs like these, you can run this transform to wipe the slate clean and delete the device specifier from all ops. +### remove_control_dependencies + +Args: None \ +Prerequisites: None + +Removes all control dependencies from the graph. + ### remove_nodes Args: diff --git a/tensorflow/tools/graph_transforms/remove_control_dependencies.cc b/tensorflow/tools/graph_transforms/remove_control_dependencies.cc new file mode 100644 index 0000000000000000000000000000000000000000..cba6b78fc5c43ca17f4f30930eb74efdb12940a1 --- /dev/null +++ b/tensorflow/tools/graph_transforms/remove_control_dependencies.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace graph_transforms { + +// Remove control depdencies in preparation for inference. +// In the tensorflow graph, control dependencies are represented as extra +// inputs which are referenced with "^tensor_name". +// See node_def.proto for more details. +Status RemoveControlDependencies(const GraphDef& input_graph_def, + const TransformFuncContext& context, + GraphDef* output_graph_def) { + output_graph_def->Clear(); + for (const NodeDef& node : input_graph_def.node()) { + NodeDef* new_node = output_graph_def->mutable_node()->Add(); + *new_node = node; + new_node->clear_input(); + for (const auto& input : node.input()) { + if (input[0] != '^') { + new_node->add_input(input); + } + } + } + return Status::OK(); +} + +REGISTER_GRAPH_TRANSFORM("remove_control_dependencies", RemoveControlDependencies); + +} // namespace graph_transforms +} // namespace tensorflow diff --git a/tensorflow/tools/graph_transforms/remove_nodes.cc b/tensorflow/tools/graph_transforms/remove_nodes.cc index 119b44d6a4a4d066b734ae8a0e655c771087d0db..05f036a86a09b2a6a94e9c1a1220803eabc64da5 100644 --- a/tensorflow/tools/graph_transforms/remove_nodes.cc +++ b/tensorflow/tools/graph_transforms/remove_nodes.cc @@ -81,7 +81,17 @@ Status RemoveNodes(const GraphDef& input_graph_def, return Status::OK(); } const NodeDef& input_node = match.inputs[0].node; - inputs_to_rename[replace_node.name()] = input_node.name(); + string target_name = input_node.name(); + for (const string& input : replace_node.input()) { + if (!input.compare(0, target_name.size(), target_name)) { + if (input.size() == target_name.size() || + input[target_name.size()] == ':') { + target_name = input; + break; + } + } + } + inputs_to_rename[replace_node.name()] = target_name; inputs_to_rename["^" + replace_node.name()] = "^" + input_node.name(); new_nodes->push_back(input_node); diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/BUILD.bazel b/tensorflow/tools/integration_tests/gcs_smoke_test/BUILD.bazel new file mode 100755 index 0000000000000000000000000000000000000000..439d86c5d2c10d15f68247c0df42ce488c10d6be --- /dev/null +++ b/tensorflow/tools/integration_tests/gcs_smoke_test/BUILD.bazel @@ -0,0 +1,56 @@ +package(default_visibility = ["//visibility:public"]) + +load("@rbe_integration_test//skylark:integration_tests.bzl", "sut_component", "integration_test") +load("@rbe_integration_test//skylark:toolchains.bzl", "toolchain_container_images") + +sut_component( + name = "gcs", + docker_image = toolchain_container_images()["tensorflow"], + setups = [{ + "program": "setup.sh", + "args": [ + "gs://tensorflow-test-bucket/tf-gcs-test", + ], + "output_properties": ["gcs_path"], + "timeout_seconds": 100, + }], + teardowns = [{ + "program": "teardown.sh", + "args": ["{gcs_path}"], + "timeout_seconds": 100, + }], +) + +py_binary( + name = "gcs_smoke", + srcs = ["gcs_smoke.py"], +) + +sh_binary( + name = "test_wrapper", + srcs = ["test_wrapper.sh"], + data = [ + "gcs_smoke", + ], +) + +integration_test( + name = "gcs_smoke_test", + sut_deps = { + ":gcs": "gcs", + }, + tags = [ + "manual", + "notap", + ], + test = { + "program": ":test_wrapper", + "args": [ + "--gcs_bucket_url={gcs#gcs_path}", + "--num_examples=20", + ], + "timeout_seconds": 250, + }, + test_docker_image = toolchain_container_images()["tensorflow"], + test_type = "MultiMachine", +) diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py b/tensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py new file mode 100755 index 0000000000000000000000000000000000000000..8438c2156cb09b4d8c9442d9a5f4de67e59272f2 --- /dev/null +++ b/tensorflow/tools/integration_tests/gcs_smoke_test/gcs_smoke.py @@ -0,0 +1,253 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Smoke test for reading records from GCS to TensorFlow.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys +import time + +import numpy as np +import tensorflow as tf +from tensorflow.core.example import example_pb2 +from tensorflow.python.lib.io import file_io + +flags = tf.app.flags +flags.DEFINE_string("gcs_bucket_url", "", + "The URL to the GCS bucket in which the temporary " + "tfrecord file is to be written and read, e.g., " + "gs://my-gcs-bucket/test-directory") +flags.DEFINE_integer("num_examples", 10, "Number of examples to generate") + +FLAGS = flags.FLAGS + + +def create_examples(num_examples, input_mean): + """Create ExampleProto's containing data.""" + ids = np.arange(num_examples).reshape([num_examples, 1]) + inputs = np.random.randn(num_examples, 1) + input_mean + target = inputs - input_mean + examples = [] + for row in range(num_examples): + ex = example_pb2.Example() + ex.features.feature["id"].bytes_list.value.append(str(ids[row, 0])) + ex.features.feature["target"].float_list.value.append(target[row, 0]) + ex.features.feature["inputs"].float_list.value.append(inputs[row, 0]) + examples.append(ex) + return examples + + +def create_dir_test(): + """Verifies file_io directory handling methods.""" + + # Test directory creation. + starttime_ms = int(round(time.time() * 1000)) + dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms) + print("Creating dir %s" % dir_name) + file_io.create_dir(dir_name) + elapsed_ms = int(round(time.time() * 1000)) - starttime_ms + print("Created directory in: %d milliseconds" % elapsed_ms) + + # Check that the directory exists. + dir_exists = file_io.is_directory(dir_name) + assert dir_exists + print("%s directory exists: %s" % (dir_name, dir_exists)) + + # Test recursive directory creation. + starttime_ms = int(round(time.time() * 1000)) + recursive_dir_name = "%s/%s/%s" % (dir_name, + "nested_dir1", + "nested_dir2") + print("Creating recursive dir %s" % recursive_dir_name) + file_io.recursive_create_dir(recursive_dir_name) + elapsed_ms = int(round(time.time() * 1000)) - starttime_ms + print("Created directory recursively in: %d milliseconds" % elapsed_ms) + + # Check that the directory exists. + recursive_dir_exists = file_io.is_directory(recursive_dir_name) + assert recursive_dir_exists + print("%s directory exists: %s" % (recursive_dir_name, recursive_dir_exists)) + + # Create some contents in the just created directory and list the contents. + num_files = 10 + files_to_create = ["file_%d.txt" % n for n in range(num_files)] + for file_num in files_to_create: + file_name = "%s/%s" % (dir_name, file_num) + print("Creating file %s." % file_name) + file_io.write_string_to_file(file_name, "test file.") + + print("Listing directory %s." % dir_name) + starttime_ms = int(round(time.time() * 1000)) + directory_contents = file_io.list_directory(dir_name) + print(directory_contents) + elapsed_ms = int(round(time.time() * 1000)) - starttime_ms + print("Listed directory %s in %s milliseconds" % (dir_name, elapsed_ms)) + assert set(directory_contents) == set(files_to_create + ["nested_dir1/"]) + + # Test directory renaming. + dir_to_rename = "%s/old_dir" % dir_name + new_dir_name = "%s/new_dir" % dir_name + file_io.create_dir(dir_to_rename) + assert file_io.is_directory(dir_to_rename) + assert not file_io.is_directory(new_dir_name) + + starttime_ms = int(round(time.time() * 1000)) + print("Will try renaming directory %s to %s" % (dir_to_rename, new_dir_name)) + file_io.rename(dir_to_rename, new_dir_name) + elapsed_ms = int(round(time.time() * 1000)) - starttime_ms + print("Renamed directory %s to %s in %s milliseconds" % ( + dir_to_rename, new_dir_name, elapsed_ms)) + assert not file_io.is_directory(dir_to_rename) + assert file_io.is_directory(new_dir_name) + + # Test Delete directory recursively. + print("Deleting directory recursively %s." % dir_name) + starttime_ms = int(round(time.time() * 1000)) + file_io.delete_recursively(dir_name) + elapsed_ms = int(round(time.time() * 1000)) - starttime_ms + dir_exists = file_io.is_directory(dir_name) + assert not dir_exists + print("Deleted directory recursively %s in %s milliseconds" % ( + dir_name, elapsed_ms)) + + +def create_object_test(): + """Verifies file_io's object manipulation methods .""" + starttime_ms = int(round(time.time() * 1000)) + dir_name = "%s/tf_gcs_test_%s" % (FLAGS.gcs_bucket_url, starttime_ms) + print("Creating dir %s." % dir_name) + file_io.create_dir(dir_name) + + num_files = 5 + # Create files of 2 different patterns in this directory. + files_pattern_1 = ["%s/test_file_%d.txt" % (dir_name, n) + for n in range(num_files)] + files_pattern_2 = ["%s/testfile%d.txt" % (dir_name, n) + for n in range(num_files)] + + starttime_ms = int(round(time.time() * 1000)) + files_to_create = files_pattern_1 + files_pattern_2 + for file_name in files_to_create: + print("Creating file %s." % file_name) + file_io.write_string_to_file(file_name, "test file creation.") + elapsed_ms = int(round(time.time() * 1000)) - starttime_ms + print("Created %d files in %s milliseconds" % + (len(files_to_create), elapsed_ms)) + + # Listing files of pattern1. + list_files_pattern = "%s/test_file*.txt" % dir_name + print("Getting files matching pattern %s." % list_files_pattern) + starttime_ms = int(round(time.time() * 1000)) + files_list = file_io.get_matching_files(list_files_pattern) + elapsed_ms = int(round(time.time() * 1000)) - starttime_ms + print("Listed files in %s milliseconds" % elapsed_ms) + print(files_list) + assert set(files_list) == set(files_pattern_1) + + # Listing files of pattern2. + list_files_pattern = "%s/testfile*.txt" % dir_name + print("Getting files matching pattern %s." % list_files_pattern) + starttime_ms = int(round(time.time() * 1000)) + files_list = file_io.get_matching_files(list_files_pattern) + elapsed_ms = int(round(time.time() * 1000)) - starttime_ms + print("Listed files in %s milliseconds" % elapsed_ms) + print(files_list) + assert set(files_list) == set(files_pattern_2) + + # Test renaming file. + file_to_rename = "%s/oldname.txt" % dir_name + file_new_name = "%s/newname.txt" % dir_name + file_io.write_string_to_file(file_to_rename, "test file.") + assert file_io.file_exists(file_to_rename) + assert not file_io.file_exists(file_new_name) + + print("Will try renaming file %s to %s" % (file_to_rename, file_new_name)) + starttime_ms = int(round(time.time() * 1000)) + file_io.rename(file_to_rename, file_new_name) + elapsed_ms = int(round(time.time() * 1000)) - starttime_ms + print("File %s renamed to %s in %s milliseconds" % ( + file_to_rename, file_new_name, elapsed_ms)) + assert not file_io.file_exists(file_to_rename) + assert file_io.file_exists(file_new_name) + + # Delete directory. + print("Deleting directory %s." % dir_name) + file_io.delete_recursively(dir_name) + + +def main(argv): + del argv # Unused. + # Sanity check on the GCS bucket URL. + if not FLAGS.gcs_bucket_url or not FLAGS.gcs_bucket_url.startswith("gs://"): + print("ERROR: Invalid GCS bucket URL: \"%s\"" % FLAGS.gcs_bucket_url) + sys.exit(1) + + # Verify that writing to the records file in GCS works. + print("\n=== Testing writing and reading of GCS record file... ===") + example_data = create_examples(FLAGS.num_examples, 5) + with tf.python_io.TFRecordWriter(FLAGS.gcs_bucket_url) as hf: + for e in example_data: + hf.write(e.SerializeToString()) + + print("Data written to: %s" % FLAGS.gcs_bucket_url) + + # Verify that reading from the tfrecord file works and that + # tf_record_iterator works. + record_iter = tf.python_io.tf_record_iterator(FLAGS.gcs_bucket_url) + read_count = 0 + for _ in record_iter: + read_count += 1 + print("Read %d records using tf_record_iterator" % read_count) + + if read_count != FLAGS.num_examples: + print("FAIL: The number of records read from tf_record_iterator (%d) " + "differs from the expected number (%d)" % (read_count, + FLAGS.num_examples)) + sys.exit(1) + + # Verify that running the read op in a session works. + print("\n=== Testing TFRecordReader.read op in a session... ===") + with tf.Graph().as_default() as _: + filename_queue = tf.train.string_input_producer([FLAGS.gcs_bucket_url], + num_epochs=1) + reader = tf.TFRecordReader() + _, serialized_example = reader.read(filename_queue) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + tf.train.start_queue_runners() + index = 0 + for _ in range(FLAGS.num_examples): + print("Read record: %d" % index) + sess.run(serialized_example) + index += 1 + + # Reading one more record should trigger an exception. + try: + sess.run(serialized_example) + print("FAIL: Failed to catch the expected OutOfRangeError while " + "reading one more record than is available") + sys.exit(1) + except tf.errors.OutOfRangeError: + print("Successfully caught the expected OutOfRangeError while " + "reading one more record than is available") + + create_dir_test() + create_object_test() + +if __name__ == "__main__": + tf.app.run(main) diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/setup.sh b/tensorflow/tools/integration_tests/gcs_smoke_test/setup.sh new file mode 100755 index 0000000000000000000000000000000000000000..6553ba5e3093c26d3c95f40216cd3922a1fb9e4e --- /dev/null +++ b/tensorflow/tools/integration_tests/gcs_smoke_test/setup.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +GCS_NUMBER=$(cat /dev/urandom | tr -dc 'A-F0-9' | fold -w 8 | head -n 1) +GCS_PATH="$1"/"$GCS_NUMBER".tfrecord + +echo "gcs_path=$GCS_PATH" > "$_SETUP_OUTPUT" +touch "$_SETUP_DONE" diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/teardown.sh b/tensorflow/tools/integration_tests/gcs_smoke_test/teardown.sh new file mode 100755 index 0000000000000000000000000000000000000000..852486d1677ec597fe56111ffb0e470c333c1cd7 --- /dev/null +++ b/tensorflow/tools/integration_tests/gcs_smoke_test/teardown.sh @@ -0,0 +1,26 @@ +#!/bin/bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +GSUTIL_BIN="/var/gcloud/google-cloud-sdk/bin/gsutil" + +echo "Got teardown argument $1" + +if "${GSUTIL_BIN}" rm "$1" +then + echo "Cleaned up new tfrecord file in GCS: '$1'" +else + echo "FAIL: Unable to clean up new tfrecord file in GCS: '$1'" + exit 1 +fi diff --git a/tensorflow/tools/integration_tests/gcs_smoke_test/test_wrapper.sh b/tensorflow/tools/integration_tests/gcs_smoke_test/test_wrapper.sh new file mode 100755 index 0000000000000000000000000000000000000000..ef29dee3462c21d6318a6fb7e7e658961f0d88dd --- /dev/null +++ b/tensorflow/tools/integration_tests/gcs_smoke_test/test_wrapper.sh @@ -0,0 +1,21 @@ +# This is a python2 only test. +#!/bin/bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Test Tensorflow package installation. +/usr/local/bin/pip install --user tf-nightly + +# Test Tensorflow interaction with GCS. +python tensorflow/tools/integration_test/gcs_smoke_test/gcs_smoke.py "$@" diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD index 614457e8996491a60d4a7df213180117bce321ad..3fbdb5cacd1fd0039deaae5ac330b6c2ca006a68 100644 --- a/tensorflow/tools/lib_package/BUILD +++ b/tensorflow/tools/lib_package/BUILD @@ -27,6 +27,7 @@ pkg_tar( ":cheaders", ":clib", ":clicenses", + ":eager_cheaders", ], ) @@ -57,7 +58,6 @@ pkg_tar( name = "cheaders", files = [ "//tensorflow/c:headers", - "//tensorflow/c/eager:headers", ], package_dir = "include/tensorflow/c", # Mark as "manual" till @@ -68,6 +68,20 @@ pkg_tar( tags = ["manual"], ) +pkg_tar( + name = "eager_cheaders", + files = [ + "//tensorflow/c/eager:headers", + ], + package_dir = "include/tensorflow/c/eager", + # Mark as "manual" till + # https://github.com/bazelbuild/bazel/issues/2352 + # and https://github.com/bazelbuild/bazel/issues/1580 + # are resolved, otherwise these rules break when built + # with Python 3. + tags = ["manual"], +) + pkg_tar( name = "clib", files = ["//tensorflow:libtensorflow.so"], diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index 02a34518c04a6ef738e46002ae4d07c801cc58f8..fb6eaa4faa28b4f6b17e1774907c0c9ff58d6ada 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -150,7 +150,7 @@ sh_binary( "//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test", "//tensorflow/contrib/data/python/ops:contrib_op_loader", "//tensorflow/contrib/eager/python/examples:examples_pip", - "//tensorflow/contrib/eager/python:checkpointable", + "//tensorflow/contrib/eager/python:checkpointable_utils", "//tensorflow/contrib/eager/python:evaluator", "//tensorflow/contrib/gan:gan", "//tensorflow/contrib/graph_editor:graph_editor_pip", diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py index e4ca974e1b7d1e86d4b64d0035df389b5fffe3c2..4b6f123daa7b528173234a2bffd30ead2aa9fc0e 100644 --- a/tensorflow/tools/pip_package/setup.py +++ b/tensorflow/tools/pip_package/setup.py @@ -29,17 +29,17 @@ from setuptools.dist import Distribution # This version string is semver compatible, but incompatible with pip. # For pip, we will remove all '-' characters from this string, and use the # result for pip. -_VERSION = '1.6.0-rc0' +_VERSION = '1.6.0-rc1' REQUIRED_PACKAGES = [ 'absl-py >= 0.1.6', 'astor >= 0.6.0', 'gast >= 0.2.0', 'grpcio >= 1.8.6', - 'numpy >= 1.12.1', + 'numpy >= 1.13.3', 'six >= 1.10.0', 'protobuf >= 3.4.0', - 'tensorflow-tensorboard >= 1.5.0, < 1.6.0', + 'tensorboard >= 1.6.0, < 1.7.0', 'termcolor >= 1.1.0', ] diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index d0f9a8925faf93d0c3f851da7262115105c6517d..9009f08163cfada60c95e6fe07b54f84d6dd96c2 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -5,6 +5,7 @@ load("//third_party/tensorrt:tensorrt_configure.bzl", "tensorrt_configure") load("//third_party/mkl:build_defs.bzl", "mkl_repository") load("//third_party/git:git_configure.bzl", "git_configure") load("//third_party/py:python_configure.bzl", "python_configure") + load("//third_party/sycl:sycl_configure.bzl", "sycl_configure") load("//third_party/toolchains/clang6:repo.bzl", "clang6_configure") load("//third_party/toolchains/cpus/arm:arm_compiler_configure.bzl", "arm_compiler_configure") @@ -126,6 +127,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): sha256 = "0cadb31a35b514bf2dfd6b5d38205da94ef326ec6908fc3fd7c269948467214f", strip_prefix = "eigen-eigen-2355b229ea4c", build_file = str(Label("//third_party:eigen.BUILD")), + patch_file = str(Label("//third_party:eigen_fix_cuda_compilation.patch")) ) tf_http_archive( @@ -179,11 +181,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "gemmlowp", urls = [ - "https://mirror.bazel.build/github.com/google/gemmlowp/archive/d4d1e29a62192d8defdc057b913ef36ca582ac98.zip", - "https://github.com/google/gemmlowp/archive/d4d1e29a62192d8defdc057b913ef36ca582ac98.zip", + "https://mirror.bazel.build/github.com/google/gemmlowp/archive/7c7c744640ddc3d0af18fb245b4d23228813a71b.zip", + "https://github.com/google/gemmlowp/archive/7c7c744640ddc3d0af18fb245b4d23228813a71b.zip", ], - sha256 = "e2bee7afd3c43028f23dd0d7f85ddd8b21aaf79c572b658e56164ef502b2b9c7", - strip_prefix = "gemmlowp-d4d1e29a62192d8defdc057b913ef36ca582ac98", + sha256 = "b852cc90259a7357c8a323f108f2cec6e85979fc3b18b5590b99e0130044b2cf", + strip_prefix = "gemmlowp-7c7c744640ddc3d0af18fb245b4d23228813a71b", ) tf_http_archive( @@ -213,6 +215,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""): urls = [ "https://mirror.bazel.build/www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", "http://pkgs.fedoraproject.org/repo/pkgs/nasm/nasm-2.12.02.tar.bz2/d15843c3fb7db39af80571ee27ec6fad/nasm-2.12.02.tar.bz2", + "http://www.nasm.us/pub/nasm/releasebuilds/2.12.02/nasm-2.12.02.tar.bz2", ], sha256 = "00b0891c678c065446ca59bcee64719d0096d54d6886e6e472aeee2e170ae324", strip_prefix = "nasm-2.12.02", @@ -473,11 +476,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "llvm", urls = [ - "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/562d4e516ab92302b34b7f4c8833455699bb48de.tar.gz", - "https://github.com/llvm-mirror/llvm/archive/562d4e516ab92302b34b7f4c8833455699bb48de.tar.gz", + "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/8f7bcdf3c65b9a47e35653d525135beb18f3ac25.tar.gz", + "https://github.com/llvm-mirror/llvm/archive/8f7bcdf3c65b9a47e35653d525135beb18f3ac25.tar.gz", ], - sha256 = "cd041cda90f2e29fd3053f3faca182ad7ed871045d789c339d0f7c7d25310ef2", - strip_prefix = "llvm-562d4e516ab92302b34b7f4c8833455699bb48de", + sha256 = "63d4da54dc7bc9a79e2ad266d230f4f759520cccb344a2dd49c2c6383ab75285", + strip_prefix = "llvm-8f7bcdf3c65b9a47e35653d525135beb18f3ac25", build_file = str(Label("//third_party/llvm:llvm.BUILD")), ) @@ -664,15 +667,12 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "cub_archive", urls = [ - "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip", - "https://github.com/NVlabs/cub/archive/1.7.4.zip", + "https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.8.0.zip", + "https://github.com/NVlabs/cub/archive/1.8.0.zip", ], - sha256 = "20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31", - strip_prefix = "cub-1.7.4", + sha256 = "6bfa06ab52a650ae7ee6963143a0bbc667d6504822cbd9670369b598f18c58c3", + strip_prefix = "cub-1.8.0", build_file = str(Label("//third_party:cub.BUILD")), - # TODO: remove the patch when upstream fix is accepted and released. - # PR with a fix: https://github.com/NVlabs/cub/pull/125 - patch_file = str(Label("//third_party/cub:fix_compilation_in_clang.patch")), ) tf_http_archive( @@ -690,13 +690,23 @@ def tf_workspace(path_prefix="", tf_repo_name=""): tf_http_archive( name = "bazel_toolchains", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/f3b09700fae5d7b6e659d7cefe0dcc6e8498504c.tar.gz", - "https://github.com/bazelbuild/bazel-toolchains/archive/f3b09700fae5d7b6e659d7cefe0dcc6e8498504c.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz", + "https://github.com/bazelbuild/bazel-toolchains/archive/44200e0c026d86c53470d107b3697a3e46469c43.tar.gz", ], - sha256 = "ed829b5eea8af1f405f4cc3d6ecfc3b1365bb7843171036030a31b5127002311", - strip_prefix = "bazel-toolchains-f3b09700fae5d7b6e659d7cefe0dcc6e8498504c", + strip_prefix = "bazel-toolchains-44200e0c026d86c53470d107b3697a3e46469c43", + sha256 = "699b55a6916c687f4b7dc092dbbf5f64672cde0dc965f79717735ec4e5416556", ) + tf_http_archive( + name = "rbe_integration_test", + urls = [ + "http://mirror.bazel.build/github.com/google/rbe-integration-test/archive/78a6194c7dda200b9522cf07707e3bc695804d1e.tar.gz", + "https://github.com/google/rbe-integration-test/archive/78a6194c7dda200b9522cf07707e3bc695804d1e.tar.gz", + ], + sha256 = "66d93b3919a165d486c31f5290d312abe9fda2685242f812c110653c124e1db4", + strip_prefix = "rbe-integration-test-78a6194c7dda200b9522cf07707e3bc695804d1e", + ) + tf_http_archive( name = "arm_neon_2_x86_sse", sha256 = "c8d90aa4357f8079d427e87a6f4c493da1fa4140aee926c05902d7ec1533d9a5", diff --git a/third_party/cub/BUILD b/third_party/cub/BUILD deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/third_party/cub/fix_compilation_in_clang.patch b/third_party/cub/fix_compilation_in_clang.patch deleted file mode 100644 index 384e674f2012b2b3ea59c5c1bd205873baa8cf18..0000000000000000000000000000000000000000 --- a/third_party/cub/fix_compilation_in_clang.patch +++ /dev/null @@ -1,23 +0,0 @@ -From 565b77f7c82048871a4d5e3e506dc663d53cd469 Mon Sep 17 00:00:00 2001 -From: Ilya Biryukov -Date: Fri, 26 Jan 2018 18:46:06 +0100 -Subject: [PATCH] Added missing 'template' keyword. - -To unbreak compilation with clang. ---- - cub/device/dispatch/dispatch_radix_sort.cuh | 2 +- - 1 file changed, 1 insertion(+), 1 deletion(-) - -diff --git a/cub/device/dispatch/dispatch_radix_sort.cuh b/cub/device/dispatch/dispatch_radix_sort.cuh -index 7fbc621f..f622e212 100644 ---- a/cub/device/dispatch/dispatch_radix_sort.cuh -+++ b/cub/device/dispatch/dispatch_radix_sort.cuh -@@ -104,7 +104,7 @@ __global__ void DeviceRadixSortUpsweepKernel( - CTA_SYNC(); - - // Write out digit counts (striped) -- upsweep.ExtractCounts(d_spine, gridDim.x, blockIdx.x); -+ upsweep.template ExtractCounts(d_spine, gridDim.x, blockIdx.x); - } - - diff --git a/third_party/eigen_fix_cuda_compilation.patch b/third_party/eigen_fix_cuda_compilation.patch new file mode 100644 index 0000000000000000000000000000000000000000..b921a7c31d5c96c79cd3033b13c60a8f7e63ba75 --- /dev/null +++ b/third_party/eigen_fix_cuda_compilation.patch @@ -0,0 +1,38 @@ +diff --git a/Eigen/src/Core/ProductEvaluators.h b/Eigen/src/Core/ProductEvaluators.h +--- a/Eigen/src/Core/ProductEvaluators.h ++++ b/Eigen/src/Core/ProductEvaluators.h +@@ -137,7 +137,7 @@ struct Assignment::type> + { + typedef Product SrcXprType; +- static EIGEN_STRONG_INLINE ++ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE + void run(DstXprType &dst, const SrcXprType &src, const internal::assign_op &) + { + Index dstRows = src.rows(); +@@ -390,7 +390,7 @@ struct generic_product_impl::Scalar Scalar; + + template +- static EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) ++ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void evalTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) + { + // Same as: dst.noalias() = lhs.lazyProduct(rhs); + // but easier on the compiler side +@@ -398,14 +398,14 @@ struct generic_product_impl +- static EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) ++ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void addTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) + { + // dst.noalias() += lhs.lazyProduct(rhs); + call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::add_assign_op()); + } + + template +- static EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) ++ static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void subTo(Dst& dst, const Lhs& lhs, const Rhs& rhs) + { + // dst.noalias() -= lhs.lazyProduct(rhs); + call_assignment_no_alias(dst, lhs.lazyProduct(rhs), internal::sub_assign_op()); diff --git a/third_party/examples/eager/spinn/README.md b/third_party/examples/eager/spinn/README.md index 6bd3d53e56d01e15491ecd383dcc763a19d75b88..7f477d19208257474d0481ca04c04679f589b751 100644 --- a/third_party/examples/eager/spinn/README.md +++ b/third_party/examples/eager/spinn/README.md @@ -66,3 +66,44 @@ Other eager execution examples can be found under [tensorflow/contrib/eager/pyth ```bash tensorboard --logdir /tmp/spinn-logs ``` + +- After training, you may use the model to perform inference on input data in + the SNLI data format. The premise and hypotheses sentences are specified with + the command-line flags `--inference_premise` and `--inference_hypothesis`, + respecitvely. Each sentence should include the words, as well as parentheses + representing a binary parsing of the sentence. The words and parentheses + should all be separated by spaces. For instance, + + ```bash + python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \ + --inference_premise '( ( The dog ) ( ( is running ) . ) )' \ + --inference_hypothesis '( ( The dog ) ( moves . ) )' + ``` + + which will generate an output like the following, due to the semantic + consistency of the two sentences. + + ```none + Inference logits: + entailment: 1.101249 (winner) + contradiction: -2.374171 + neutral: -0.296733 + ``` + + By contrast, the following sentence pair: + + ```bash + python spinn.py --data_root /tmp/spinn-data --logdir /tmp/spinn-logs \ + --inference_premise '( ( The dog ) ( ( is running ) . ) )' \ + --inference_hypothesis '( ( The dog ) ( rests . ) )' + ``` + + will give you an output like the following, due to the semantic + contradiction of the two sentences. + + ```none + Inference logits: + entailment: -1.070098 + contradiction: 2.798695 (winner) + neutral: -1.402287 + ``` diff --git a/third_party/examples/eager/spinn/spinn.py b/third_party/examples/eager/spinn/spinn.py index a2fa18eeb1077c8a1ccd4ab0bcd178f952e17270..8a1c7db2ea14365be53a796a79fce77900e668e1 100644 --- a/third_party/examples/eager/spinn/spinn.py +++ b/third_party/examples/eager/spinn/spinn.py @@ -471,6 +471,15 @@ class SNLIClassifierTrainer(object): def learning_rate(self): return self._learning_rate + @property + def model(self): + return self._model + + @property + def variables(self): + return (self._model.variables + [self.learning_rate] + + self._optimizer.variables()) + def _batch_n_correct(logits, label): """Calculate number of correct predictions in a batch. @@ -488,13 +497,12 @@ def _batch_n_correct(logits, label): tf.argmax(logits, axis=1), label)), tf.float32)).numpy() -def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu): +def _evaluate_on_dataset(snli_data, batch_size, trainer, use_gpu): """Run evaluation on a dataset. Args: snli_data: The `data.SnliData` to use in this evaluation. batch_size: The batch size to use during this evaluation. - model: An instance of `SNLIClassifier` to evaluate. trainer: An instance of `SNLIClassifierTrainer to use for this evaluation. use_gpu: Whether GPU is being used. @@ -509,7 +517,7 @@ def _evaluate_on_dataset(snli_data, batch_size, model, trainer, use_gpu): snli_data, batch_size): if use_gpu: label, prem, hypo = label.gpu(), prem.gpu(), hypo.gpu() - logits = model(prem, prem_trans, hypo, hypo_trans, training=False) + logits = trainer.model(prem, prem_trans, hypo, hypo_trans, training=False) loss_val = trainer.loss(label, logits) batch_size = tf.shape(label)[0] mean_loss(loss_val, weights=batch_size.gpu() if use_gpu else batch_size) @@ -536,13 +544,19 @@ def _get_dataset_iterator(snli_data, batch_size): return tfe.Iterator(dataset) -def train_spinn(embed, train_data, dev_data, test_data, config): - """Train a SPINN model. +def train_or_infer_spinn(embed, + word2index, + train_data, + dev_data, + test_data, + config): + """Perform Training or Inference on a SPINN model. Args: embed: The embedding matrix as a float32 numpy array with shape [vocabulary_size, word_vector_len]. word_vector_len is the length of a word embedding vector. + word2index: A `dict` mapping word to word index. train_data: An instance of `data.SnliData`, for the train split. dev_data: Same as above, for the dev split. test_data: Same as above, for the test split. @@ -550,13 +564,35 @@ def train_spinn(embed, train_data, dev_data, test_data, config): details. Returns: - 1. Final loss value on the test split. - 2. Final fraction of correct classifications on the test split. + If `config.inference_premise ` and `config.inference_hypothesis` are not + `None`, i.e., inference mode: the logits for the possible labels of the + SNLI data set, as a `Tensor` of three floats. + else: + The trainer object. + Raises: + ValueError: if only one of config.inference_premise and + config.inference_hypothesis is specified. """ + # TODO(cais): Refactor this function into separate one for training and + # inference. use_gpu = tfe.num_gpus() > 0 and not config.force_cpu device = "gpu:0" if use_gpu else "cpu:0" print("Using device: %s" % device) + if ((config.inference_premise and not config.inference_hypothesis) or + (not config.inference_premise and config.inference_hypothesis)): + raise ValueError( + "--inference_premise and --inference_hypothesis must be both " + "specified or both unspecified, but only one is specified.") + + if config.inference_premise: + # Inference mode. + inference_sentence_pair = [ + data.encode_sentence(config.inference_premise, word2index), + data.encode_sentence(config.inference_hypothesis, word2index)] + else: + inference_sentence_pair = None + log_header = ( " Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss" " Accuracy Dev/Accuracy") @@ -569,16 +605,36 @@ def train_spinn(embed, train_data, dev_data, test_data, config): summary_writer = tf.contrib.summary.create_file_writer( config.logdir, flush_millis=10000) - train_len = train_data.num_batches(config.batch_size) + with tf.device(device), \ - tfe.restore_variables_on_create( - tf.train.latest_checkpoint(config.logdir)), \ summary_writer.as_default(), \ tf.contrib.summary.always_record_summaries(): - model = SNLIClassifier(config, embed) - global_step = tf.train.get_or_create_global_step() - trainer = SNLIClassifierTrainer(model, config.lr) - + with tfe.restore_variables_on_create( + tf.train.latest_checkpoint(config.logdir)): + model = SNLIClassifier(config, embed) + global_step = tf.train.get_or_create_global_step() + trainer = SNLIClassifierTrainer(model, config.lr) + + if inference_sentence_pair: + # Inference mode. + with tfe.restore_variables_on_create( + tf.train.latest_checkpoint(config.logdir)): + prem, prem_trans = inference_sentence_pair[0] + hypo, hypo_trans = inference_sentence_pair[1] + hypo_trans = inference_sentence_pair[1][1] + inference_logits = model( # pylint: disable=not-callable + tf.constant(prem), tf.constant(prem_trans), + tf.constant(hypo), tf.constant(hypo_trans), training=False) + inference_logits = inference_logits[0][1:] + max_index = tf.argmax(inference_logits) + print("\nInference logits:") + for i, (label, logit) in enumerate( + zip(data.POSSIBLE_LABELS, inference_logits)): + winner_tag = " (winner)" if max_index == i else "" + print(" {0:<16}{1:.6f}{2}".format(label + ":", logit, winner_tag)) + return inference_logits + + train_len = train_data.num_batches(config.batch_size) start = time.time() iterations = 0 mean_loss = tfe.metrics.Mean() @@ -594,23 +650,24 @@ def train_spinn(embed, train_data, dev_data, test_data, config): # remain on CPU. Same in _evaluate_on_dataset(). iterations += 1 - batch_train_loss, batch_train_logits = trainer.train_batch( - label, prem, prem_trans, hypo, hypo_trans) + with tfe.restore_variables_on_create( + tf.train.latest_checkpoint(config.logdir)): + batch_train_loss, batch_train_logits = trainer.train_batch( + label, prem, prem_trans, hypo, hypo_trans) batch_size = tf.shape(label)[0] mean_loss(batch_train_loss.numpy(), weights=batch_size.gpu() if use_gpu else batch_size) accuracy(tf.argmax(batch_train_logits, axis=1), label) if iterations % config.save_every == 0: - all_variables = ( - model.variables + [trainer.learning_rate] + [global_step]) + all_variables = trainer.variables + [global_step] saver = tfe.Saver(all_variables) saver.save(os.path.join(config.logdir, "ckpt"), global_step=global_step) if iterations % config.dev_every == 0: dev_loss, dev_frac_correct = _evaluate_on_dataset( - dev_data, config.batch_size, model, trainer, use_gpu) + dev_data, config.batch_size, trainer, use_gpu) print(dev_log_template.format( time.time() - start, epoch, iterations, 1 + batch_idx, train_len, @@ -638,10 +695,12 @@ def train_spinn(embed, train_data, dev_data, test_data, config): trainer.decay_learning_rate(config.lr_decay_by) test_loss, test_frac_correct = _evaluate_on_dataset( - test_data, config.batch_size, model, trainer, use_gpu) + test_data, config.batch_size, trainer, use_gpu) print("Final test loss: %g; accuracy: %g%%" % (test_loss, test_frac_correct * 100.0)) + return trainer + def main(_): config = FLAGS @@ -650,18 +709,24 @@ def main(_): vocab = data.load_vocabulary(FLAGS.data_root) word2index, embed = data.load_word_vectors(FLAGS.data_root, vocab) - print("Loading train, dev and test data...") - train_data = data.SnliData( - os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"), - word2index, sentence_len_limit=FLAGS.sentence_len_limit) - dev_data = data.SnliData( - os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"), - word2index, sentence_len_limit=FLAGS.sentence_len_limit) - test_data = data.SnliData( - os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"), - word2index, sentence_len_limit=FLAGS.sentence_len_limit) - - train_spinn(embed, train_data, dev_data, test_data, config) + if not (config.inference_premise or config.inference_hypothesis): + print("Loading train, dev and test data...") + train_data = data.SnliData( + os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_train.txt"), + word2index, sentence_len_limit=FLAGS.sentence_len_limit) + dev_data = data.SnliData( + os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_dev.txt"), + word2index, sentence_len_limit=FLAGS.sentence_len_limit) + test_data = data.SnliData( + os.path.join(FLAGS.data_root, "snli/snli_1.0/snli_1.0_test.txt"), + word2index, sentence_len_limit=FLAGS.sentence_len_limit) + else: + train_data = None + dev_data = None + test_data = None + + train_or_infer_spinn( + embed, word2index, train_data, dev_data, test_data, config) if __name__ == "__main__": @@ -678,6 +743,15 @@ if __name__ == "__main__": parser.add_argument("--logdir", type=str, default="/tmp/spinn-logs", help="Directory in which summaries will be written for " "TensorBoard.") + parser.add_argument("--inference_premise", type=str, default=None, + help="Premise sentence for inference. Must be " + "accompanied by --inference_hypothesis. If specified, " + "will override all training parameters and perform " + "inference.") + parser.add_argument("--inference_hypothesis", type=str, default=None, + help="Hypothesis sentence for inference. Must be " + "accompanied by --inference_premise. If specified, will " + "override all training parameters and perform inference.") parser.add_argument("--epochs", type=int, default=50, help="Number of epochs to train.") parser.add_argument("--batch_size", type=int, default=128,

  • Version:CPU/GPU:Python Version:Compiler:Build Tools:cuDNN:CUDA:
    tensorflow-1.6.0rc0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
    tensorflow_gpu-1.6.0rc0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
    tensorflow-1.6.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
    tensorflow_gpu-1.6.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
    tensorflow-1.5.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A
    tensorflow_gpu-1.5.0GPU3.5-3.6MSVC 2015 update 3Cmake v3.6.379
    tensorflow-1.4.0CPU3.5-3.6MSVC 2015 update 3Cmake v3.6.3N/AN/A